{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7e7a43de",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\env\\Anaconda3\\lib\\site-packages\\numpy\\_distributor_init.py:30: UserWarning: loaded more than 1 DLL from .libs:\n",
      "D:\\env\\Anaconda3\\lib\\site-packages\\numpy\\.libs\\libopenblas.QVLO2T66WEPI7JZ63PS3HMOHFEY472BC.gfortran-win_amd64.dll\n",
      "D:\\env\\Anaconda3\\lib\\site-packages\\numpy\\.libs\\libopenblas.WCDJNK7YVMPZQ2ME2ZZHJJRJ3JIKNDB7.gfortran-win_amd64.dll\n",
      "  warnings.warn(\"loaded more than 1 DLL from .libs:\"\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.autograd import Variable\n",
    "import torch.utils.data as Data\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from torchsummary import summary\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e443697a",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "# Pytorch与Numpy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a2dfa140",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "数组的数据格式: <class 'numpy.ndarray'>\n",
      "张量的数据格式: <class 'torch.Tensor'>\n",
      "tensor([[1, 2, 3],\n",
      "        [4, 5, 6],\n",
      "        [7, 8, 9]], dtype=torch.int32)\n",
      "tensor([[0.2157, 0.4747, 0.4749],\n",
      "        [0.1501, 0.8676, 0.8471],\n",
      "        [0.4970, 0.0711, 0.4230],\n",
      "        [0.9386, 0.3070, 0.3278],\n",
      "        [0.5340, 0.6519, 0.6013]])\n",
      "tensor([[0.2157, 0.4747, 0.4749],\n",
      "        [0.1501, 0.8676, 0.8471],\n",
      "        [0.4970, 0.0711, 0.4230],\n",
      "        [0.9386, 0.3070, 0.3278],\n",
      "        [0.5340, 0.6519, 0.6013]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# numpy转tensor\n",
    "data_list = np.array([[1,2,3],[4,5,6],[7,8,9]])\n",
    "data_tensor = torch.from_numpy(data_list)\n",
    "\n",
    "print('数组的数据格式: {}'.format(type(data_list)))\n",
    "print('张量的数据格式: {}'.format(type(data_tensor)))\n",
    "print(data_tensor)\n",
    "\n",
    "# 直接生成\n",
    "x = torch.Tensor(5, 3)\n",
    "y = torch.rand(5, 3)\n",
    "# 相加\n",
    "z = x+y\n",
    "print(z)\n",
    "\n",
    "# CPU与GPU的转换\n",
    "if torch.cuda.is_available():\n",
    "    x_gpu = x.cuda()\n",
    "    y_gpu = y.cuda()\n",
    "    z_gpu = x_gpu + y_gpu\n",
    "    print(z_gpu)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ff6a597",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "# 自动求导与梯度计算"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c3a428c4",
   "metadata": {
    "hidden": true,
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x18388303ac0>]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAT3klEQVR4nO3db6icZ5nH8e+Vc3S1FckpTUvOabupENREkEqQrgUpniPrarF9U6kkIetWAue4WkXQum/6SugLEX2xCYSqDXaollpoWcTVHBXZF9vd1ArbNCst/qn5Y3Nck1VktUlz7YuZo6dp/s3zPDPz/Pl+oMzMM/Oc5xqSXueXe+77nshMJEntsm7SBUiSqmdzl6QWsrlLUgvZ3CWphWzuktRC05MuAODqq6/OTZs2TboMSWqUp5566jeZueF8z9WiuW/atImDBw9OugxJapSI+OWFnnNYRpJayOYuSS1kc5ekFrK5S1IL2dwlqYVs7pI0AUtLMD0NEf3bpaVqf34tpkJKUlcsLcHeva889vLLfzm2Z0811zG5S9KYzMy8urGvtW9fddeyuUvSiC0t9YdfTp26+Otefrm6azosI0kjNDNz6aa+amqquuua3CVpBC43ra+1e3d11ze5S1KFej3YuROG/QbT2dnqPkwFm7skVWZhAZaXhz9vyxY4dKjaWhyWkaSSej1Yt274xh4BDz1UfWMHk7sklVKntL7WJZN7RHw1Ik5ExDNrjl0VEd+LiOcGtzNrnvtcRDwfET+NiL8dVeGSNEl1TOtrXc6wzIPA+845di+wnJmbgeXBYyJiC3AXsHVwzp6IqHByjyRN3tatsGPH8B+abtkCZ8/C9u2jqWutSzb3zPwR8NtzDt8O7B/c3w/cseb4NzLzT5n5c+B54J3VlCpJk9Xr9ZP3s88Od9640vpaRcfcr83M4wCZeTwirhkcnwP+fc3rjgyOvUpE7AZ2A9xwww0Fy5Ck8di6dfimDqMfW7+QqmfLxHmOnfcfLpm5LzO3Zea2DRvO+/2ukjRxq4uRmpDW1yqa3F+MiI2D1L4RODE4fgS4fs3rrgOOlSlQkiah6GIkmFxaX6tocn8C2DW4vwt4fM3xuyLiryLiRmAz8B/lSpSk8VpYKPaB6dTUZNP6WpdM7hHxMHArcHVEHAHuA+4HHomIu4EXgDsBMvNQRDwCPAucAT6WmRXucyZJo1Mmrc/Pw4ED1ddU1CWbe2Z++AJPzV/g9Z8HPl+mKEkat6IfmE5Nwf7945neOAy3H5DUaUWnNwIsLsKZM/Vr7OD2A5I6rGhan52Fo0err6dKJndJnVN2MVLdGzuY3CV1TNMWIxVlcpfUCUXTep2mNw7D5C6p9Yqm9bpNbxyGyV1Sa5VN601t7GByl9RSXUzra5ncJbVK0Y2+2pDW1zK5S2qNuTk4VmCrwsVF2LOn+nomyeYuqfGWlmDv3uHPW78eTp6svJxacFhGUqPNzRVr7IuL7W3sYHKX1FCm9YszuUtqlNXpjab1izO5S2qMNm/0VTWTu6TaK7stb9caO5jcJdWcab0Yk7ukWiq6GAm6m9bXMrlLqp2ii5Gati3vKJncJdXGaloftrGvfomGjf0vTO6SamFmBk6dGv68Nm4dUAWbu6SJcjHSaDgsI2liZmZcjDQqNndJY7c6tj7sMMzsLGQ6DHM5bO6SxqbXg3Xriqf1rk9vHIZj7pLGYmEBlpeHP8/pjcWY3CWN1GpaH7axO72xHJO7pJExrU+OyV1S5Uzrk2dyl1Spoht9mdarVSq5R8SnIuJQRDwTEQ9HxOsi4qqI+F5EPDe4namqWEn1VXRbXtP6aBRu7hExB3wC2JaZbwOmgLuAe4HlzNwMLA8eS2qxrVthx47hz9uyBc6ehe3bq6+p68qOuU8Dr4+IaeAK4BhwO7B/8Px+4I6S15BUU0W35TWtj17hMffMPBoRXwBeAP4P+G5mfjcirs3M44PXHI+Ia853fkTs
BnYD3HDDDUXLkDQBvR7s3NlfLTosx9bHo8ywzAz9lH4jMAtcGRGX/Q+zzNyXmdsyc9uGDRuKliFpzBYW+kMwwzb2qSnT+jiVGZZZAH6emSuZeRp4DHgX8GJEbAQY3J4oX6akSSs6vRFgfh7OnHFsfZzKNPcXgJsj4oqICGAeOAw8AewavGYX8Hi5EiVNWtm0fuDAaOrShZUZc38yIh4FfgycAZ4G9gFvAB6JiLvp/wK4s4pCJY1fmbH1+Xmb+iSVWsSUmfcB951z+E/0U7ykBiu6GGl6Gh580CGYSXP7AUmvUHYx0unTNvY6cPsBSX/m1gHtYXKX5NYBLWRylzrOtN5OJnepo4qmdRcjNYPJXeqgomnd6Y3NYXKXOqToRl8uRmoek7vUEXNzcOzY8OeZ1pvJ5C613GpaH7axm9abzeQutVjRtL64CHv2VF+PxsfmLrXQ0hLs3Tv8eevXw8mTlZejCXBYRmqR1emNRRr74qKNvU1M7lJLFJ3eODsLR49WX48my+QuNVzRxUjQT+s29nYyuUsNZlrXhZjcpQYquhgJTOtdYXKXGqbo9EY3+uoWk7vUEEUXI7ktbzeZ3KUGcDGShmVzl2rMxUgqymEZqaZmZlyMpOJs7lLNrI6tnzo13Hmzs5DpMIz6HJaRamRmZvimDo6t69VM7lINmNZVNZu7NEG9HqxbV3xs3cVIuhCHZaQJWViA5eXhz3Mxki6HyV0as9W0PmxjdzGShmFyl8bItK5xMblLY2Ba17iZ3KURK7otr2ldZZjcpREpui2vaV1VKJXcI2I98ADwNiCBfwB+CnwT2AT8AvhQZroYWp1SdDGSaV1VKZvcvwx8JzPfArwdOAzcCyxn5mZgefBY6oSii5FM66pa4eQeEW8E3g38PUBmvgS8FBG3A7cOXrYf+CHw2TJFSnXX68HOnf3VosMyrWsUyiT3NwErwNci4umIeCAirgSuzczjAIPba853ckTsjoiDEXFwZWWlRBnSZC0swI4dwzf2qSnTukanTHOfBt4B7M3Mm4A/MMQQTGbuy8xtmbltw4YNJcqQJqPo9EaA+Xk4cwa2b6++LgnKNfcjwJHMfHLw+FH6zf7FiNgIMLg9Ua5EqX62bi2W1qen+2n9wIHR1CWtKtzcM/PXwK8i4s2DQ/PAs8ATwK7BsV3A46UqlGqk1ys3vfH0adO6xqPsIqaPA72IeC3wM+Aj9H9hPBIRdwMvAHeWvIZUCy5GUpOUau6Z+RNg23memi/zc6U66fX6QzDDioCvf92krslw+wHpIkzraiq3H5DOo+jYutMbVRcmd+kcRdP6/LyzYFQfJndpoGxat7GrTkzuEqZ1tY/JXZ1WdFte07rqzuSuzpqbg2PHhj9vcRH27Km+HqlKNnd1ztIS7N07/Hnr18NJv5lADeGwjDplbq5YY19ctLGrWUzu6oSFhWK7N5rW1VQmd7Xa6vTGIo3dtK4mM7mrtYpOb5ydhaNHq69HGieTu1qn6GIk6Kd1G7vawOSuVjGtS30md7VC0cVIYFpXO5nc1XhFFyO5dYDazOSuxlpN68M29vXr+999amNXm9nc1UguRpIuzmEZNYpbB0iXx+SuxpiZMa1Ll8vmrtpbHVs/dWq482Zn+2Pr7uCoLrK5q7Z6PVi3rnhad3qjuswxd9VS0Y2+XIwk9ZncVSurab3oRl82dqnP5K7aKJrWt2yBQ4eqr0dqMpO7Jq5oWo/of4+pjV16NZO7JqroRl+mdeniTO6aiKIbfZnWpctjctfYzcwMP2cdTOvSMEzuGpuii5FM69LwSif3iJgCDgJHM/O2iLgK+CawCfgF8KHMdPF3h/V6sHNnf7XosEzrUjFVJPd7gMNrHt8LLGfmZmB58FgdtbAAO3YM39inpkzrUhmlmntEXAd8AHhgzeHbgf2D+/uBO8pcQ81UZjHS/DycOQPbt1dfl9QVZZP7l4DPAGfXHLs2M48DDG6vOd+JEbE7Ig5GxMGVlZWSZahOiqb1
6el+WvdLNKTyCjf3iLgNOJGZTxU5PzP3Zea2zNy2YcOGomWoRsouRjp92rQuVaXMB6q3AB+MiPcDrwPeGBEPAS9GxMbMPB4RG4ETVRSqenMxklQvhZN7Zn4uM6/LzE3AXcD3M3MH8ASwa/CyXcDjpatUbfV6LkaS6mgUi5juBx6JiLuBF4A7R3AN1YBpXaqvShYxZeYPM/O2wf3/ycz5zNw8uP1tFddQfZjWpfpz+wENxbQuNYPbD+iyFE3rLkaSJsPkrksqmtbn552zLk2KyV0XVDat29ilyTG567xM61Kzmdz1CkW/RMO0LtWLyV1/NjcHx44Nf97iIuzZU309koqzuYuFhWK7N65fDyfdqV+qJYdlOmz1A9MijX1x0cYu1ZnJvaOKfmBqWpeaweTeMUWnN4JpXWoSk3uHFE3rs7Nw9Gj19UgaHZN7BxSd3gj9tG5jl5rH5N5yRac3uhhJajaTe0utpvVhG/v69f3vPrWxS81mc2+huTnYu3f48/zAVGoPh2VaZGmpWFN3eqPUPib3lpiZMa1L+gube8Otjq2fOjXcebOz/bF194SR2slhmQabmRm+qYMbfUldYHJvINO6pEuxuTdIrwfr1hUfW3cxktQdDss0RNFtebds8cuppS4yudfcaloftrFH9L8ZycYudZPJvcZM65KKMrnX0OoHpqZ1SUWZ3Gum6PRG07qktUzuNVF0eqNpXdL5mNxrwLQuqWom9wkyrUsaFZP7BPR6sHNnf7XosEzrki5H4eQeEddHxA8i4nBEHIqIewbHr4qI70XEc4PbmerKbb6FBdixY/jGPj1tWpd0+coMy5wBPp2ZbwVuBj4WEVuAe4HlzNwMLA8ed17ZxUinT8P27aOpTVL7FB6WyczjwPHB/d9HxGFgDrgduHXwsv3AD4HPlqqy4VyMJGncKvlANSI2ATcBTwLXDhr/6i+Aay5wzu6IOBgRB1dWVqooo3bcOkDSpJT+QDUi3gB8C/hkZv4uIi7rvMzcB+wD2LZtW4GPFutt61Z49tnhzzOtS6pCqeQeEa+h39h7mfnY4PCLEbFx8PxG4ES5Epul1+sn72Ebu2ldUpUKJ/foR/SvAIcz84trnnoC2AXcP7h9vFSFDWJal1QXZZL7LcBO4D0R8ZPBf++n39TfGxHPAe8dPG4107qkuikzW+bfgAsNsM8X/blNY1qXVEduP1BQ0bQ+NWValzR6bj9QQNG0Pj8PBw5UX48kncvkPoTVjb6KpnUbu6RxMblfprk5OHZs+PMWF2HPnurrkaSLsblfQtGtA9avh5MnKy9Hki6LwzIXsPqBaZHGvrhoY5c0WSb38yj6galpXVJdmNzXKDq9EUzrkurF5D5QNK3PzsLRo9XXI0lldD65l03rNnZJddTp5O7WAZLaqpPJvexiJBu7pLrrXHJ3MZKkLuhMc19agr17hz/P6Y2SmqgTwzJzc8Uau9MbJTVVq5O7aV1SV7U2uc/MmNYldVfrmvvqTJhTp4Y7b3YWMv3QVFI7tGpYZmZm+KYOzoSR1D6tSO6mdUl6pUY3914P1q0rPrbu1gGS2qqxwzK9HuzYMfx5bh0gqQsam9w/+tHhXh/h1gGSuqOxyf2Pf7z815rWJXVNY5P75TCtS+qqxib3SzGtS+qyxib3+fnzHzetS1KDm/uBA69u8PPzcPYsbN8+mZokqS4aPSxz4MCkK5CkempscpckXdjImntEvC8ifhoRz0fEvaO6jiTp1UbS3CNiCvhn4O+ALcCHI2LLKK4lSXq1USX3dwLPZ+bPMvMl4BvA7SO6liTpHKNq7nPAr9Y8PjI49mcRsTsiDkbEwZWVlRGVIUndNKrZMnGeY/mKB5n7gH0AEbESEb8scb2rgd+UOL9puvZ+wffcFb7n4fz1hZ4YVXM/Aly/5vF1wLELvTgzN5S5WEQczMxtZX5Gk3Tt/YLvuSt8z9UZ1bDMfwKbI+LGiHgtcBfwxIiuJUk6x0iSe2aeiYh/BP4VmAK+
mpluCCBJYzKyFaqZ+W3g26P6+efYN6br1EXX3i/4nrvC91yRyMxLv0qS1ChuPyBJLWRzl6QWanRz79r+NRFxfUT8ICIOR8ShiLhn0jWNS0RMRcTTEfEvk65lHCJifUQ8GhH/Pfjz/ptJ1zRKEfGpwd/pZyLi4Yh43aRrGoWI+GpEnIiIZ9YcuyoivhcRzw1uZ6q4VmObe0f3rzkDfDoz3wrcDHysA+951T3A4UkXMUZfBr6TmW8B3k6L33tEzAGfALZl5tvoz7C7a7JVjcyDwPvOOXYvsJyZm4HlwePSGtvc6eD+NZl5PDN/PLj/e/r/w89d/Kzmi4jrgA8AD0y6lnGIiDcC7wa+ApCZL2XmqYkWNXrTwOsjYhq4gossemyyzPwR8NtzDt8O7B/c3w/cUcW1mtzcL7l/TZtFxCbgJuDJCZcyDl8CPgOcnXAd4/ImYAX42mAo6oGIuHLSRY1KZh4FvgC8ABwH/jczvzvZqsbq2sw8Dv0AB1xTxQ9tcnO/5P41bRURbwC+BXwyM3836XpGKSJuA05k5lOTrmWMpoF3AHsz8ybgD1T0T/U6Gowx3w7cCMwCV0bEjslW1XxNbu5D7V/TFhHxGvqNvZeZj026njG4BfhgRPyC/tDbeyLiocmWNHJHgCOZufqvskfpN/u2WgB+npkrmXkaeAx414RrGqcXI2IjwOD2RBU/tMnNvXP710RE0B+HPZyZX5x0PeOQmZ/LzOsycxP9P+PvZ2arU11m/hr4VUS8eXBoHnh2giWN2gvAzRFxxeDv+Dwt/gD5PJ4Adg3u7wIer+KHNvYLsju6f80twE7gvyLiJ4Nj/zTY6kHt8nGgNwguPwM+MuF6RiYzn4yIR4Ef058R9jQt3YYgIh4GbgWujogjwH3A/cAjEXE3/V90d1ZyLbcfkKT2afKwjCTpAmzuktRCNndJaiGbuyS1kM1dklrI5i5JLWRzl6QW+n+8R9YaM+j5ZwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 生成数据\n",
    "def f(x):\n",
    "    return 10*x + 4 + np.random.uniform(0,1,size=1)[0]\n",
    "x = np.linspace(0,10,1000).astype(np.float32)\n",
    "y = f(x).astype(np.float32)\n",
    "\n",
    "x = x.reshape(1000,1)\n",
    "y = y.reshape(1000,1)\n",
    "\n",
    "#画出图像\n",
    "import matplotlib.pyplot as plt\n",
    "plt.plot(x, y, 'bo')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b01484e4",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, loss: 190.23776245117188\n",
      "epoch: 1, loss: 22.192134857177734\n",
      "epoch: 2, loss: 5.180789470672607\n",
      "epoch: 3, loss: 3.4329521656036377\n",
      "epoch: 4, loss: 3.2278966903686523\n",
      "epoch: 5, loss: 3.1790287494659424\n",
      "epoch: 6, loss: 3.1462202072143555\n",
      "epoch: 7, loss: 3.1153063774108887\n",
      "epoch: 8, loss: 3.0848538875579834\n",
      "epoch: 9, loss: 3.054716110229492\n"
     ]
    }
   ],
   "source": [
    "# 制作数据\n",
    "def gen_data():\n",
    "    # 生成数据\n",
    "    def f(x):\n",
    "        return 2 * x + 4 \n",
    "\n",
    "    x = np.linspace(0, 10, 1000).astype(np.float32)\n",
    "    y = f(x).astype(np.float32)\n",
    "\n",
    "    x = x.reshape(1000, 1)\n",
    "    y = y.reshape(1000, 1)\n",
    "    return x,y\n",
    "\n",
    "# 转换与初始化参数\n",
    "def trans_tensor(x, y):\n",
    "    # 转换张量\n",
    "    x_train = torch.from_numpy(x)\n",
    "    y_train = torch.from_numpy(y)\n",
    "\n",
    "    # 初始化参数\n",
    "    w = Variable(torch.randn(1), requires_grad=True)\n",
    "    b = Variable(torch.zeros(1), requires_grad=True)\n",
    "\n",
    "    # 构建线性回归模型\n",
    "    x_train = Variable(x_train)\n",
    "    y_train = Variable(y_train)\n",
    "    return x_train, y_train, w, b\n",
    "\n",
    "# 线性模型\n",
    "def linear_model(X, w, b):\n",
    "    return X * w + b\n",
    "\n",
    "#计算误差\n",
    "def get_loss(y_predict, y):\n",
    "    return torch.mean((y_predict - y) ** 2)\n",
    "\n",
    "x,y = gen_data()\n",
    "x_train, y_train, w, b = trans_tensor(x, y)\n",
    "\n",
    "'''\n",
    "    模拟单次\n",
    "'''\n",
    "# y_predict = linear_model(x_train, w, b)\n",
    "# loss = get_loss(y_predict, y_train)\n",
    "# #自动求导\n",
    "# loss.backward()\n",
    "# print('loss: {}'.format(loss.data))\n",
    "'''\n",
    "    模拟10个epoch\n",
    "'''\n",
    "for epoch in range(10):\n",
    "    y_predict = linear_model(x_train, w, b)\n",
    "    loss = get_loss(y_predict, y_train)\n",
    "\n",
    "    loss.backward()\n",
    "\n",
    "\n",
    "    w.data = w.data - 1e-2*w.grad.data\n",
    "    b.data = b.data - 1e-2*b.grad.data\n",
    "    print('epoch: {}, loss: {}'.format(epoch, loss.data))\n",
    "\n",
    "    w.grad.zero_()\n",
    "    b.grad.zero_()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1cdba90c",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "# PyTorch全连接层原理和使用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3f66cd45",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 0, loss: 0.448427\n",
      "epoch 1, loss: 0.210626\n",
      "epoch 2, loss: 0.110089\n",
      "epoch 3, loss: 0.041647\n",
      "epoch 4, loss: 0.020220\n",
      "epoch 5, loss: 0.014911\n",
      "epoch 6, loss: 0.006953\n",
      "epoch 7, loss: 0.003364\n",
      "epoch 8, loss: 0.001514\n",
      "epoch 9, loss: 0.000910\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.autograd import Variable\n",
    "\n",
    "import torch.utils.data as Data\n",
    "\n",
    "# 制作数据\n",
    "def gen_data():\n",
    "    # 生成数据\n",
    "    def f(x):\n",
    "        return 10 * x + 4\n",
    "\n",
    "    x = np.linspace(0, 10, 1000).astype(np.float32)+ np.random.uniform(0,1,size=1)[0]\n",
    "    y = f(x).astype(np.float32)\n",
    "\n",
    "    x = x.reshape(1000, 1)\n",
    "    y = y.reshape(1000, 1)\n",
    "    return x,y\n",
    "\n",
    "# 转换与初始化参数\n",
    "def trans_tensor(x, y):\n",
    "    # 转换张量\n",
    "    x_train = torch.from_numpy(x)\n",
    "    y_train = torch.from_numpy(y)\n",
    "\n",
    "    # 初始化参数\n",
    "    w = Variable(torch.randn(1), requires_grad=True)\n",
    "    b = Variable(torch.zeros(1), requires_grad=True)\n",
    "\n",
    "    # 构建线性回归模型\n",
    "    x_train = Variable(x_train)\n",
    "    y_train = Variable(y_train)\n",
    "    return x_train, y_train, w, b\n",
    "\n",
    "def linear_model(X, w, b):\n",
    "    return X * w + b\n",
    "\n",
    "# 定义网络模型\n",
    "class MatrixLinearNet(nn.Module):\n",
    "    def __init__(self, learning_rate):\n",
    "        super(MatrixLinearNet, self).__init__()\n",
    "        self.w = Variable(torch.randn(1), requires_grad=True)\n",
    "        self.b = Variable(torch.zeros(1), requires_grad=True)\n",
    "        self.learning_rate = learning_rate\n",
    "        # self.linear = nn.Linear(n_feature, 1)\n",
    "\n",
    "    # forward 定义前向传播\n",
    "    def forward(self, X):\n",
    "        y = X * self.w + self.b\n",
    "        return y\n",
    "\n",
    "    def update_params(self):\n",
    "        self.w.data = self.w.data - self.learning_rate * self.w.grad.data\n",
    "        self.b.data = self.b.data - self.learning_rate * self.b.grad.data\n",
    "\n",
    "        self.w.grad.zero_()\n",
    "        self.b.grad.zero_()\n",
    "\n",
    "# 定义网络模型\n",
    "class NNLinearNet(nn.Module):\n",
    "    def __init__(self, n_feature):\n",
    "        super(NNLinearNet, self).__init__()\n",
    "        self.linear = nn.Linear(n_feature, 1)\n",
    "\n",
    "    # forward 定义前向传播\n",
    "    def forward(self, X):\n",
    "        y = self.linear(X)\n",
    "        return y\n",
    "\n",
    "\n",
    "def train(model_type):\n",
    "    x, y = gen_data()\n",
    "    x_train, y_train, w, b = trans_tensor(x, y)\n",
    "\n",
    "    # 将训练数据的特征和标签组合\n",
    "    dataset = Data.TensorDataset(x_train, y_train)\n",
    "    # 随机读取小批量\n",
    "    data_iter = Data.DataLoader(dataset, 10, shuffle=True)\n",
    "    # 定义训练批次\n",
    "    epochs = 10\n",
    "    learning_rate = 0.01\n",
    "    if model_type == 0: # 使用矩阵方式实现网络\n",
    "        # 创建矩阵模型\n",
    "        model = MatrixLinearNet(learning_rate)\n",
    "        loss = nn.MSELoss()\n",
    "\n",
    "        for epoch in range(epochs):\n",
    "            for X, y in data_iter:\n",
    "                output = model(X)\n",
    "                l = loss(output, y.view(-1, 1))\n",
    "                l.backward()\n",
    "\n",
    "                model.update_params()\n",
    "            print('epoch %d, loss: %f' % (epoch, l.item()))\n",
    "    else: # 使用nn.Linear实现网络\n",
    "        model = NNLinearNet(len(x[0]))\n",
    "        loss = nn.MSELoss()\n",
    "        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)\n",
    "        for epoch in range(epochs):\n",
    "            for X, y in data_iter:\n",
    "                output = model(X)\n",
    "                l = loss(output, y.view(-1, 1))\n",
    "                optimizer.zero_grad()  # 梯度清零，等价于net.zero_grad()\n",
    "                l.backward()\n",
    "                optimizer.step()\n",
    "            print('epoch %d, loss: %f' % (epoch, l.item()))\n",
    "\n",
    "type_dict = {\n",
    "    'matrix': 0,\n",
    "    'nn': 1\n",
    "}\n",
    "train(type_dict['nn'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a339270",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "# PyTorch激活函数原理和使用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "id": "32be6a39",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD4CAYAAAAJmJb0AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAABQcklEQVR4nO3dd3yUVbrA8d+ZkimZ9B4SUiAhlGDoIB2kCYIU2+qq2K9rWxcX0FVQ17bLddVV8VqxoKJ0pINUQRAQ6Z0AISGV9GQy5dw/EiKRTmYySTjfz4cPM2857/MO8HBy5rzPEVJKFEVRlMZJ4+kAFEVRFPdRSV5RFKURU0leURSlEVNJXlEUpRFTSV5RFKUR03k6gLMFBwfL2NhYT4ehKIrSoGzdujVHShlyvn31KsnHxsayZcsWT4ehKIrSoAghjl1onxquURRFacRUklcURWnEVJJXFEVpxOrVmLyiKA2DzWYjLS2N8vJyT4dyTTEajURFRaHX6y/7HJXkFUW5Ymlpafj4+BAbG4sQwtPhXBOklOTm5pKWlkZcXNxln6eGaxRFuWLl5eUEBQWpBF+HhBAEBQVd8U9PKskrinJVVIKve1fzmaskryiK4kFSSkoLCygvKXZL+yrJK4rSaFksFk+HcFEOm43Tp9IpzM6ivNg9SV598aooSoMmpURKiUbTcPqsUkrKigopys0BwDc4BJOvn1uu1XA+FUVRlCqpqam0bNmSRx99lPbt2/Pyyy/TqVMn2rZty6RJk845fvXq1QwbNqz6/WOPPca0adPqMOLfnd171xsMBEU1xezn77bvOFRPXlGUWnlxwW72pBe6tM1Wkb5Muqn1RY/Zv38/n332GTfffDMzZ85k8+bNSCkZPnw4a9eupVevXi6NqbbO7b2HYvL1dfsX2CrJK4rSIMXExNC1a1fGjRvHsmXLaNeuHQDFxcUcPHiwXiV5h81GQU4WFaWleJlM+IaEoat6oKmitBiRl4YTLYaoBJdfWyV5RVFq5VI9bnfx9vYGKnvIEydO5OGHH77gsTqdDqfTWf2+rp7UvVjvXTqdlGWkYqIYhIMK6e2WGNSYvKIoDdqgQYP49NNPKa6anXLy5EmysrJqHBMTE8OePXuwWq0UFBSwcuVKt8flsNk4nfHHsXc/hBCUF+TiTN+PWRRgQ8txL28KAwPcEofqySuK0qANHDiQvXv30q1bN6By2uRXX31FaGho9THR0dHceuuttG3bloSEhOqhHXc4p/ceEorJp7L37nQ4qMg4gkGUgIBsYSLXAE5px+SmeISU0k1NX7mOHTtKtWiIotR/e/fupWXLlp4Oo95x2GwUZGdRUXbu2HtpdjrGitNoRAVl0kCGUU+ZrMCoM6FxBGLWGwnzNV7yGuf77IUQW6WUHc93vOrJK4qi1NLFeu92aznO7GOYRCkSLRkaM6f1DgR2fPUhFJYYcErw9nJPbCrJK4qi1ELN3rsZ35DQ33vvGakYnUXoNHaKpJFMkxar04a3zge71Y/TZRJvg5YofxMGvdYt8akkryiKchUu1nu3Fhegzc/ArCnDjo6TOm8KNTZ0aPDRhVNQrEMjICrATIBZ79a58irJK4qiXCG7zUZhdiYVZWV4mcz4hYSi1esrp0WmH8FIMQgnpzGRZRLYnTZ8vPwpK7OQb5P4m/RE+BvRa90/wVEleUVRlMt0sd57WV4WXqU5mDRWKqQX6QYvSqjASxiwaELJLxR4aQWxwWZ8jZe/slNtqSSvKIpyGWr03s1m/IIre+8Ouw37qVSMogSEIEuYyTU4kdjw1QVRVGqi3CEJthgI8zWi1dRtHX6V5BVFUS5CSklZYSFFeef23kuz0jDa8jFobJRKIxlGHeWyApPWDPYATheBSa8hNsiE2csz6VY98aooSoNTmzrxV3Ku3WbjdMZJCnOy0BuNBEc1xezrh91aRkXafky2bCRO0rVmUo0SGw589aGUlgTy10cfZcuqRTQP
tXgswYPqySuKopzjor339KOYZDEIO6cdenItXlQ4bVh0vlRYfTldJrEYtPgZ9fiavDy+TGKtk7wQIhr4AggHnMCHUsq3hRCBwAwgFkgFbpVSnq7t9RRFqWcWT4BTO13bZngyDHn9sg7997//zXfffYfVamXkyJG8+OKLANx8882cOHGC8vJynnzySR566KEa5+Xk5HDTTTcxYcIEnnrqKQ4cOIBerycvN5eUlBR+Wr4Ubz+/6rH38sLT6ApPceOtf6Zrh3as3vobPQf3pluPbvzr+SkUFpUSEBjEJ598SlyTpmjOGnuPjY1ly5YtBAcHs2XLFsaNG8fq1atd9nFdjCt68nbgb1LKbUIIH2CrEGI5cC+wUkr5uhBiAjABGO+C6ymKogCwbNkyDh48eN5a8p9++imBgYGUlZXRqVMnRo8eTVBQEACZmZkMHz6cf/7znwwYMIC5c+fyww8/MKhfXz79vw+4cdBAgiKbYPLxRUon5WmHqurNSGxoOF5SzCfzP8MoLIwZfCv/+Xg6zaIjWbd0Hv965UU+/fRTD38yv6t1kpdSZgAZVa+LhBB7gSbACKBP1WGfA6tRSV5RGp/L7HG7w7Jlyy5YS/6dd95hzpw5AJw4cYKDBw8SFBSEzWajf//+vPfee/Tu3RuAe++9l9dffYVubdswY/ZsPv7oY8y+fpTmnsJQlodRY8VaNS2yQkiGjr4Jb00kW7ft5eC+vTxx92g0QuBwOIiIiPDY53E+Lh2TF0LEAu2ATUBY1X8ASCkzhBChFzjnIeAhgKZNm7oyHEVRGrkL1ZJfvXo1K1asYOPGjZjNZvr06VNdQ16n09GhQweWLl1Kr169KCssIKlpFMePn+C3/QcQGi1t2rTGmnagst6MEGQKM3lV0yJ1Qo9eE0lRqYZAs57kNq3ZuHHjReM8u559XdWyP8Nls2uEEBZgFvCUlPKy1wKTUn4opewopewYEhLiqnAURbkGXKiWfEFBAQEBAZjNZvbt28fPP/9cfY4Qgk8//ZS9e/Yw6blnKczJxsto5N6xYxn7wIPcOXoEIusABk0JpRg4YjSRo7dj1JkwOCOxOzTotRqah1no3vE6srOzq5O8zWZj9+7d58QZGxvL1q1bAZg1a1YdfDK/c0mSF0LoqUzw06WUs6s2ZwohIqr2RwBZFzpfURTlagwcOJA//elPdOvWjeTkZMaMGUNRURGDBw/GbrfTtm1bnn/+ebp27VrjPGtxEe+88Srr1q/n27nz8Q+P5LYxozmdl8c9w3sikaRpvTlmkNiFE199GCXFAZRXCLx0GqICTJj0Wry8vJg5cybjx4/nuuuuIyUlhQ0bNpwT56RJk3jyySfp2bMnWq17CpFdSK3ryYvK+UGfA3lSyqfO2v5vIPesL14DpZR/v1hbqp68ojQMDbWe/NlPrRrMZnyrZs6Uph9h4YI5zF/2I+/+93/JNApsTjsWvR/Wch/KbRJfo55IfxNeOs8+XuSJevLdgT8DO4UQ26u2PQu8DnwnhLgfOA7c4oJrKYqiXLHKee8FVfPeRfW8d2thHrIoi/EvvMjiHzfw0bcfk+blQC+88NFGUFCkQacVxASa8DW5t1qku7hids164EJ33r+27SuKotSG3VZBYXZWjd670GiwnjyEQZSCkEx65UX+4iVxSic+XgEUl5j5+4S/sXPb5hqVIp988knGjh3rwbu5cuqJV0VRGiUpJaWFBRT/ofdelpOBseI0Rk0F5dJA+pll+LSVy/CdLgSDTsMnH36At6Hhp8iGfweKoih/cL7eu5RObCcPYhIlSDQXXIYvzNdAiI8BTQMcmjkfleQVRWk0Lth7P3Uco7MQncZOsTRyqnoZPgv2Cj9OF4G3l5YmASaMblqGz1NUklcUpVE4p/ceEoq9vAz7yf2YNWU40JGm86agahk+X10Y+cX6qmX4jASYPV9MzB1UklcUpUE7u/cuEPiFhGGwWLBmHMVICQgHpzGRbRLYnDZ89P6UlVs4bZP4
mXRE+pvqZBk+T2m8d6YoSqNnt1VwOv0kRTnZeBlNBEU3BYcVmb4PkyjEhpZjBgvpXg6E0GDRRJBf5I3TKYgN8iYmyLs6wffp04fG+JyO6skritLgVPfec3MQQuAXGoaXyXTWMnyQVVVvxlm1DF9xqYniCjthfp5Zhs9TVJJXFKVW3tj8Bvvy9rm0zaTAJMZ3Pn/RWrutgl1btzLmzjvp1KEDu/fuo1lsU778z4tYzFrKqqZFVi7DZ6JX2z4Mv/UuNq1bxROPP0ZEaAijJ03CarXSrFkzPvvss3NWi7JYLNX1cGbOnMkPP/zAtGnTXHqPdUUN1yiK0iBIKSkpyCf3xHHstgoOHznKI488wpbFM/H31jH1829I15g5aoSKqmmRZSVBSAnBft788vMGhg4exD//+U9WrFjBtm3b6NixI2+++aanb82tVE9eUZRauVCP25XsFVUzZ8rLMJi9CYiIJCoykl4tQhGilFGjRvL2tK+4UWfHW+eDzepXvQyfTqvh/rvvQiMEP//8M3v27KF79+4AVFRU0K1bN7fH70kqySuKUm9JKSktqJo5o6kcexc4kal70QgnDgQZOm+y9U6EEPjowiko0qHVQHSAGX+zHgF4e3tXtzdgwAC++eabi1737KmUdV3/3dXUcI2iKPWSvaJq5kxuNl4mM4FNmkJhFobCY2iFleMnTzFz1y4KNTaWz11J+069yS/W4m/WkxjmQ4D3ufPeu3btyk8//cShQ4cAKC0t5cCBA+dcOywsjL179+J0OqtXl2qoVJJXFKVekVJSkp9Pbtpx7DYrfqFhGAxaROZBTKKICnSc9DIRnxjPvG/nMab3rWRnF3H73Q8QF+xNdKAZ3QXmvYeEhDBt2jTuuOMO2rZtS9euXdm379wvjV9//XWGDRtGv3796t1yfleq1vXkXUnVk1eUhsFd9eTtFRUUZGdhqxp7twQE4Mw5gZcoBQRZGiN5Xk7Sjqfx2J2PM2fFJhxOCPbxIszHiOYamBbpiXryiqIotVI59p5PcV5u9di7LC1Am3sYvcZGqTSSYdRVTYv0xiDDcDgEep2GOH8TJi+Vyi5EfTKKonjUH3vvZh8LmvwMdKIUJ1pOar0p0NnRCAe+ulAKir0IDAvg5y3bCbL8Pu4+cuRIjh49WqPtN954g0GDBnnituoNleQVRfGI8/bei3LxKsgB4aBQmMg0arA5bVh0vlRYfTldKvE16s67DF9D/4LUXVSSVxSlzlX23jOxlZdj8PbGaPRCX3ASnaYcG3oyvIwUCRt6NFXL8GnRagRNA034NdBl+DxFJXlFUerMH3vvviGhaAqzMZRkg5DkYCLHCA5pw9crkJJSM/l2SaC3nnBf4wVnzSgXppK8oih14o+9dy8tGIvS0FQtw5dh1FMqKzBqjRgd4dXL8MWHmLA0gmX4PEV9coqiuNUfe+8+QUHoi3PQO0qRCE5pzOSdZxm+UF8DoRZDraZFnl1o7FqlkryiKG5zTu9d2jCVpKPR2Gosw2fWWXBe5TJ8UkqklGg0aijnfNSnoiiKy1U+tXqa3LTjOGwVWPz9sdgL8ZZ5SCRpOm+OGZw4kPjqwigp9qe8QtDE30R8iPclE3xqaiotW7bk0UcfpX379rz88st06tSJtm3bMmnSpHOOX716NcOGDat+/9hjjzXY0sFXSvXkFUWplVOvvop17++lAaSU2G0VSKcTodGiwUEpTkDiQEOFpvIYnUaH06ml0Al6rcBLp6UYKAYMLZMIf/bZi153//79fPbZZ9x8883MnDmTzZs3I6Vk+PDhrF27ll69ern1vhsK1ZNXFMUlJOCw27FZy5FOJ1qtFp20o8WBBKxCi1VIQKAVBux2LUiBUa/FqNNecTKKiYmha9euLFu2jGXLltGuXTvat2/Pvn37OHjwoOtvsIFSPXlFUWol/Nlna469m82YHKUYRAkAORoTuV4SJ0589IEUl5iwOiTBFgNhvga0VzmWfnb54IkTJ/Lwww9f8Fid
TofT6ax+39DLB18J1ZNXFOWq1Rx7t2ExG/Bz5mPUFFOOF0eN3mTp7XhpvTARyelCI1qNhuYhFiL9TVed4M82aNAgPv300+pZNCdPniQrK6vGMTExMezZswer1UpBQQErV66s9XUbCtWTVxTlqtTovZtMeDtL0DsLkGjJ0Jg5fZ5pkeG+BoJ9DGhc+MTqwIED2bt3b/UKTxaLha+++orQ0NDqY6Kjo7n11ltp27YtCQkJtGvXzmXXr+9UqWFFUa6I0+lgx6/biQjwQ2g0eOsFJorRCDtF0sQpk4YKpw1vvQ92qx+lFRKLQUcTfxOGy5wWqVzYlZYaVsM1iqJcttyTJ/j2+b9jLSnGYPAiUFeBt8jHCZzQeXPc4MCJxEcXTlGxH1Y7RAWYiQv2VgneQ1wyXCOE+BQYBmRJKdtUbQsEZgCxQCpwq5TytCuupyhK3XI6HWz9YS4/ffcVeoMBo07gx2kQkjxMZJsEdqcNH70/ZWUW8u0Sf7OeCD8jelVvxqNc9elPAwb/YdsEYKWUMgFYWfVeUZQGJjetsve+dvpnJDQJ4Z6wQ3iJMmzoOGbwJsPLgUZosWgiyS/yRkpBXLA3TQPNKsHXAy7pyUsp1wohYv+weQTQp+r158BqYLwrrqcoivs5nQ62LJjDhu+nYzQYuLmpkzjjHJxoKRIGDhs1SGz46oIoKjFR7pSE+BgI9TGivQaW4Wso3Dm7JkxKmQEgpcwQQoSe7yAhxEPAQwBNmzZ1YziKolyu3LQTLJ36FhmH9nNdVAA9TVsx6E7ymzOJV5ICeVjrIFhrBHsAp4vApNcQF6yW4auPPP4nIqX8EPgQKmfXeDgcRbmmOR0OtvxQ2Xv3MRi4I6aMCNN6yqQ/k/w7MT8oF7MoxKT1obQkEIAIPyPBZy3Dp9Qv7kzymUKIiKpefASQdckzFEXxmNy04yyZ+hanDh2gWxMznS2b0IrTrBRteLOVmRPlGVwfcgNphwdQUq4jwqCjib8RL52aNVOfufNbkfnAPVWv7wHmufFaiqJcJafDwaa53/Pl+CdwZp7kvrjTXO+7lAKMPNmkG3+NK8CuhT5+z7Ji7QAyT+sJ9PYiNshc7xO8xWK5qvOklDz33HMkJibSsmVL3nnnHRdHVndcNYXyGyq/ZA0WQqQBk4DXge+EEPcDx4FbXHEtRVFcJ+fEMZZOfYtThw8woImONj5rARvTdSl81BxOV6RzQ8StbN3eiQW5Dm7p0ITnhrYk49jhejM844568tOmTePEiRPs27cPjUZzTpmEhsRVs2vuuMCu/q5oX1EU13I6HPwyfxYbZ35NtLeeR5qn460/whFnPC82j2CbPEGiKYnm/JU5K/XEBRv5+sE2XN8sGICMs9pa990Bck64dvWl4GgLPW9NvOD+1NRUhgwZQt++fdm4cSM333wzP/zwA1arlZEjR/Liiy/WOH716tVMmTKFH374AaisJ9+xY0fuvffe87Y/depUvv766+r/OM4ukdDQePyLV0VR6lbOiWMsef8t8o4e4qYmduIt67Bh4H+9O/BdZAlOshkU8gArN7WgqMzJ4/2a8Ze+zS97paa64s568ocPH2bGjBnMmTOHkJAQ3nnnHRISElwYfd1RSV5RrhFn995b+uq5NeEAXtpTbHG25NVWARy0ptEhoBulGcOZ+aOgXVMfXh/VlhbhPhdt92I9bnc6U09+3Lhx1fXkAYqLizl48GCtkrzVasVoNLJlyxZmz57Nfffdx7p161wVep1SSV5RrgE5x1NZMvUtio8f4baoYiJMWyl2BjE5qDOL/XLwoYQBwU+zaGMYOo2Wl0a04K4uMbVaRNvd3FlPPioqitGjRwMwcuRIxo4d64KIPUM9c6wojZjT4WDTnO/4auKTxBcd54FmvxFu/JVFoi2jW8ezwPcU3cMHYsyayOx1ofRMCGX50724u1tsvU7wZ3NHPfmbb76ZH3/8EYA1a9aQmOiZn1ZcQfXkFaWROtN7d55M5d6mWfgb9nDK
Gc2L0a1Zrz9BlC6a3j7Ps2iVN8EWPR/c1YbBbcI9HfYVc0c9+QkTJnDnnXfyn//8B4vFwscff+zWe3AnVU9eURoZp8PB5nkz+Xnm1/QPk7T23YRE8rmhFZ/F2Sm2ldAv4lY2bm1Her6Du7o25e+Dk/A16i/7Gueraa7UjSutJ6968orSiGQfT2Xp1LcwZx7jofhUzLpU9jub83JiOL85jtPSuw3motuZtVJHQqiJmY8k0zE20NNhK26kkryiNAIOu51f5s9iy+xvGBpeTmz0JiqkhdcsHZkdXoBWk8ug4P9h6c9xWG3wtwHNebh3M7x01/bXciNHjuTo0aM1tr3xxhsMGjTIQxG5nkryitLAnem9h+Ud56G43Xhps/hJtuFfrXw4Yj1Jp6CenD4+jJk7JJ3j/HltVDLNQq7ucf/GZs6cOZ4Owe1UkleUBspht/PLvJnsmvctN0XmExbxK/kyknGWQawO3oe5RNAscyxr9iaiB27WmGl33MmGt3ewoZbXThpqJveka59yvdZ5mXT4BBpd3q5K8orSAGUfT2XJe/9LUulxxsZtRUMZM7W9eDuqkHzNXlrb+pGa1o/t5To6+pi5PTQQP53r/rlrdQ70hvr1BGxDp9W7Z+hMJXlFaUDO9N6PLpzBiIg0fIMOcFwm8mJ0GJv1qcT6xtPa/jhLtppo4m/i09tb0y8pzOVx7N27F99gk8vbVVxPJXlFaSCyjx1l6bv/pqP9GLdHb8KJnvcMPZkek0e5I52B4XezdnMyu4sd3N8jjqcHJOJtUP/E7733XoYNG8aYMWM8HYpHqL8BilLPOex2Ns/7ntwlMxgTfgCj9iQ7ZHv+mWBmr+MYyb7t0ObewqyVGlpFePPJPcm0jfL3dNhKPXFtz59SlHou+9hRvnvmEaI3fcSwJqtwCiuTfHoytnkhadrTDAp7nB2/3MEvh3RMHJLE/Me6XxMJvqSkhKFDh3LdddfRpk0bZsyYwdatW+nduzcdOnRg0KBBZGRknHNebGwsOTk5AGzZsoU+ffrUceR1T/XkFaUectjtbJ77HY5V33BL0G9oRRHL6c5brRwctx7j+tD+pB8ZyMzfJD0TAnl1ZDLRgWaPxLpq2odkHTvi0jZDY+Lpe+9DF9y/ZMkSIiMjWbhwIQAFBQUMGTKEefPmERISwowZM3juuef49NNPXRpXQ6SSvKLUM9nHjrLmzUn0995HQMh+smU8L4WlsMY7lTBtOP39J/LDGn98TXreuq0VI1Ii680qTXUlOTmZcePGMX78eIYNG0ZAQAC7du1iwIABADgcDiIiIjwcZf2gkryi1BMOu51NM7/C9+dvGOW/DdDypaYPnyYWkFdxjH4Ro9n+W1fm5tgZ3b5yGb5Aby9Ph33RHre7JCYmsnXrVhYtWsTEiRMZMGAArVu3ZuPGjRc97+ySw5cqN9xYqCSvKPVAVuoRtr05nt5+OzD5n+KQbM9LsT78qjlCc2MiCTzB3JUGYoK8+Or+DvRICPZ0yB6Vnp5OYGAgd911FxaLhQ8//JDs7Gw2btxIt27dsNlsHDhwgNatW9c4LzY2lq1btzJkyBBmzZrloejrlkryiuJBDruNjV+8S+y+7xkcuBurM5R/G25gVtOTOGQxg0Lv58dNSewoc/I/feJ5sn9CvVuGzxN27tzJM888g0ajQa/XM3XqVHQ6HU888QQFBQXY7Xaeeuqpc5L8pEmTuP/++3n11Vfp0qWLh6KvW6rUsKJ4SNbRQxx9+3E6+P2GFisbZR+mJFk5aDtGu+DOVGSO5OcDguui/Xl9VDItI3w9HXI1VWrYc1SpYUWp5xx2G7+8O5FW2Qvp4p/OaUcrXg1swoqgw1g0FgYEP8nin5ugEYLJN7Xgz91i0TaQVZqU+kcleUWpQzl7fqZk2mN0NR/ErgtklnM4H7bOJN26n55hgzlyoB+ztzu5oWUIL41oTaS/Kh2g1I5K8opSBxzlxeyfcictbBsIMjk5bBvG602t/GzcThNd
FL28nmPxap+qZfhaM6h1+DU3LVJxD5XkFcWdnE5yV07FuOZftNLlU+LszEe6OGa12UuRrYh+4bexeVsHFhU4uatLDM8MbnFFy/ApyqWoJK8o7iAljn2LKPr2KYJEFjZNNBusd/DfVgfY5dxIorkVsQW3M+9HLxLDzLz7SDIdYtQyfIrrqSSvKK4kJRxdQ9n88Zjy9+Etw0mv+Avv+mayPHElQgh6+T/Aql8ScErBM4MSeLBn/DW/DJ/iPirJK4orOJ2Quha55l+IYz+hcwZy2vEYy+z+TG+3mqPW41zn341Th29k4U49/ZNCmDy8tcfqzVxLpk2bxpYtW3j33Xc9HYpHqCSvKLVRdAq2T4dtX8Lpo9icFkocD3Pc2okPo+fyU8BufPGjreFx1q+PJNTXyOu3J9EtMRgbgv3FZTidThxVv5xVv+wOJyCRsvIXUP26+smWPzzjcqFnXtzxLIyP3U6RtcLl7bpDuc1OhcNR7+PVaQQmveu/j1FJXlGulNMBh1bA1s/hwBKQDtKtYeSLh1gf2JfV5p3s932PCkcuTtv1HDs4gGN2E45YC8eb+fDU6RzYlOPpu6iVL/wFGqvdozEs+GY6X/z3bYQQJLRuw99efZ1XnnqSjLQTADzz+r9o17UbWXYHhQ7JEaud5//nIXoNGsKAm0cC0C0ylI3pWZ68jWpm6SDBrwEmeSHEYOBtQAt8LKV83d3XVBSXcjogcxcc24Dz6Ho4tgFNeR4Ven926TuzwtqS5a17syfUgrngO4ylG3HYQrGmPYSjJB7/AEnzCCcBlmIMpWV4aTUYtDq8dFq8tFp0Gg1CCLQaDRqNqH4thEAgQAiEoPI1QNXUyrNnWFbvO5sbp2D6axxEaCvbty89jswsdWn7IsyMblDTC+7ft2cP0/733yxYvpKg4GBO5+Ux8W9P8+Tjj9Pl+utJO3GC20cMZ/22X/HXCLwFRGgFZiEI0Ijq2AVUv/Y0g9Y9xebcmuSFEFrgPWAAkAb8IoSYL6Xc487rKspVcdhxFJykPPMg1qzDOHIOo8vYhuX0LvSOMgAK8OUYTdjP9WzSJrM5KpR0yym8yr4gMOMQAg2OvBsozerDkFaRPDWgBS3C6085AlfZu3cvod6VD2rl63VUaFxbT8dLr8Pf+8IPgs34eQO33XILLWOiAQj1bsL61as4cmB/9TElxUWYnHZ8DV6Y9DpCvU0YdVr8jF7VsQuoft1Yubsn3xk4JKU8AiCE+BYYAbg0yeenppK2YqYrm1Rq46qGgM8+yVlzuwSkROIECaJqv5ROhHTCmV9IwAGy6pfTDtKOkLaqbTY0jjK0znI00orWWY5WlqOX5Xg7C7BQhBaJN+BddfUsAtlJIjm6GAr18ZTpAllly2ObXx5W42p0jpN4F4CzPAJrUX9sBSkMbtGGJ+9IIOkSyV1KBw5HGVLacDrtSGlDSvvvr5EgnVX37ax+D7Lq9R/G5C/0wbthTN7p9MJuLwHAMiTc5e0D1e2ff185UtpqHON0Oli3bgUmU82k7XBYcTorj9VowGYrxW4vQUpJRUXFRa9Tl4TQotUaXd6uu5N8E+DEWe/TAJeXfstYu4A2aS+7ulmlkaqQeqx4UYEXFRiwSiPZMp4iZxBFjmCKHSEU2UPZ5wznFL7kaSTZllQKTbsQvmvQeOUipYCypngVDSVe24nkxJbEh5hpH+VFtN9pyss3ceJEOuXWdMrL06mwZmN3lOBwFGO3l+BwlOJ0lnn6o7hqgQH/R2mp56phXn99c+688z0eeugmAgP9ycsroG/fLrz11qs8+eRYAHbs2EfbtklUVGRjtxdSWnqEJk182LRpNUOHtuOHH37EZrNRWuraVa2ulk7vh9l04SGqq27X5S3WdL7BrhrdCiHEQ8BDAE2bXt0NRiTGsXN7ylWdW3dqjqf+/v6s12f/Lv5wbI3j/vBaaP6wWXOets9z7T/uO2cMV5x7bfHHuP8Y0x9jPXvb2XGc515qxCDOc5zmPMdq
fv8dTdUpWhBaBFrQ6BBCh0CDEFo0GhMarQ4QOJAUOsrJtZWQZy8h115KbkUxu21W9moDKDOfQue7HL3PboSuGI3U4nAm4lfYgb/49qPrMW+MXa04O5wiv2Aq+fm/kHEgm7NXFtVoDBiNkXh5hWI0RqDVeqPVmtFpvdHqLGi1JjRCjxB6hEaHRuiqXwu0VaUNNAihqf5MRPXSzL9/Xucdkz/nz9J1MjLMmM1xLm3zSnToEMezz05k2LCH0Wi0pKS05b//fZ8nnvgb3bvfgcNhp0eP7rz//hC8vILR6Xwxm+N45JG/MmrUbfTvP5Z+/Xrj7e3t0fs4mxDuScduLTUshOgGTJZSDqp6PxFASvna+Y6/6lLD1iLI2FGZ7ETVP/rqfxRU/aMXZ+0/s+/s488+50LHnvX7Oeefp00EaNRDLp5WaivlWOExjhQc4UjBEY4WHOVowVGOFR7D5rQBlSMaxvKulGT2ptwG3pGzEd4HAS/KzSmgS6avtgUvdWpN8eG55JxcTXnQIRyayh/1DYYIAvw74+PTGqOxCUZjJEZjJHp9UKOsQaNKDXtOfSs1/AuQIISIA04CtwN/cvVFjuUL5hwIuoozndQc/1Xqg5r9Dnmebb//OHimk+KUUOEop8RRSKm9kBJbIWWOQqyOYmzO8sqxfKlBSC06YvBytiRBGsCpJ7fUwIkCL4ornHiFbsQ7aBkIQbHfnZh0Pbm/SMdNJfspMX7Jnt1bkRo7Xn6RhEbeSEBAF/z9O2MyNXH756IoV8OtSV5KaRdCPAYspXIK5adSyt2uvs76Xw/wukH1mBUN4F/169JEsQ3d3gK0eVa0fqfwjZuN1BwnRLalX/GfSMjTE2+eT1nUGk6FZqOVFsJ0IwkPHk1giw4IVYpAaQDcPk9eSrkIWOTOa3RLDOaFRavceQnFpS49fCGRWJ3lWB3lWB1llDsqX1c4y3Cc6d0Deo0eo86EUWfCpDNj1JmxeFkwGbzRabXotFq0Wi16nRadTo/JYECr82JBahlzDuZj1Dvo0H4Tv5UvxF/vx7gWL9MrqA0ncj8mM38OhdgJ8O9KZOSzhIQMQqs1uPmzURTXahRPvMbHN+XRx+7xdBjKVSiwFlSOledXjpWfGTM/WXyyekqgQNDE0oR4/3jifOMqf/eLI843Dn+j/xVdb9W+LJ6bv4sTeWX0bptHhn4620pPMjphNH9pew+n079i69GJSCmJbHI7TaPHYjbHuv7GFaWONIokr9RvTunkVMmp6gR+9pefeeV51ccZtAZifWNpE9yG4c2GVyZyvzhifGMw6mo3f/hUQTkvLtjN4l2niAt1ckOvVWzKXk6sKZaP+r9DcPkv7NpyE1JWEBE+mtjYx9Q4u9IoqCSvuIzVYeVY4bEaiTy1IJXUwlTK7L/PCfcz+BHvF0/f6L7ViTzOL45I70i0Ln5y0uGUfPXzMf69dD82h4Ph3U+yrfgLtuaW8HDbBxnkp+Xk4ac45ighPGw4cXGP15spdYriCirJK1eswFpQ3RM/k9CPFBzhZPFJnPL32UqR3pHE+cfRIawD8f7xxPtVDrMEGutmcYy9GYVMnL2T7Sfy6ZzgwCt8LqtyfqFdaDueSBpMxamPOX76GMHB/WkWPw6LJbFO4lLc44EHHuDpp5+mVatWbrvGjTfeyNdff42/v3+N7ZMnT8ZisTBu3Di3XftqqSSvnJeUkszSzMqx8sKjHMn/fY55bnlu9XF6jZ4Y3xiSApO4Me7G6kQe4xuDWe+ZWullFQ7eXnmQj9cdwdekYVTf/azN/hp9gZ6/t3uE1s4tnD76PGZzPCnXfUZQUC+PxKm41scff+z2ayxa5NY5JG6hkvw1zuawcbzoeI0e+Zke+tlDLD5ePsT7xdMzqmd1Io/zi6OJpQk6Tf35a7TuYDbPzdnF8bxSBqRYyTZ+xfJTB+kX3Yc/hflTkvVfCjUGEpo/S1TUn9Fo3FP5T3GvkpISbr31VtLS0nA4HDz//PNMnTqVKVOm
0LFjRz755BPeeOMNIiMjSUhIwGAw8O6773LvvfdiMpnYt28fx44d47PPPuPzzz9n48aNdOnShWnTpgHwzTff8OqrryKlZOjQobzxxhsAxMbGsmXLFoKDg3nllVf44osviI6OJiQkhA4dOnjwE7mw+vOvU3Gr4oriGmPlZ34/UXQCh3RUHxfuHU6cbxyjEkbVmMkSZKzfT27mFlt5ZeFeZv96kpgQLSP6bWZVxhxCtCG82O7PhBbOpTgzk4iIW2jWbBwGr2BPh9xoLF68mFOnTrm0zfDwcIYMGXLB/UuWLCEyMpKFCxcCUFBQwNSpUwFIT0/n5ZdfZtu2bfj4+NCvXz+uu+666nNPnz7Njz/+yPz587npppv46aef+Pjjj+nUqRPbt28nNDSU8ePHs3XrVgICAhg4cCBz587l5ptvrm5j69atfPvtt/z666/Y7Xbat2+vkrziflJKssuyf0/k+b/3yrPKfl8YQafREeMTQ3P/5gyIGUCcXxzxfvHE+sXirfe+yBXqHykls7ed5J8L91BUbufm6/PZZZ3GqoxMxjQfyg3GNMpy/g8vS2uSk9/Hzy/F0yErLpCcnMy4ceMYP348w4YNo2fPntX7Nm/eTO/evQkMrPzu55ZbbuHAgQPV+2+66SaEECQnJxMWFkZycjIArVu3JjU1lWPHjtGnTx9CQkIAuPPOO1m7dm2NJL9u3TpGjhyJ2Vw5JDl8+HB33/JVU0m+AbI77ZwoOnHOdMSjBUcpthVXH2fRW4j3i6drZNfqIZZ4v3ia+DRBr3H9CjR1LTWnhOfm7uSnQ7m0jRGExy1lZeaPNPOL5+mU4ZhOz6LC7kViwvM0aXIXmno0rNSYXKzH7S6JiYls3bqVRYsWMXHiRAYOHFi971L1uAyGygfaNBpN9esz7+12Ozrd5f09qc8/2Z5N/a2vx0ptpeck8iMFRzhedBy78/el10JNocT5x3FTs5uqE3mcXxwhppAG8xfxStgcTj5ad4S3VxxEr4XRfdLYmPc5J7OtPJA0gvbOjVTkfkNw6FASE57DYAjzdMiKi6WnpxMYGMhdd92FxWKpHksH6Ny5M3/96185ffo0Pj4+zJo1q7q3fjm6dOnCk08+SU5ODgEBAXzzzTc8/vjjNY7p1asX9957LxMmTMBut7NgwQIefvhhV92eS6kk72FSSnLLc38fXjlrJktmaWb1cVqhJdonunp++ZkpibG+sVi8LB68g7q1/UQ+E2btYN+pInq0dOAI+p5lmdvpEJrCnWHeaPK/QWtqSsp10wgK6nnpBpUGaefOnTzzzDNoNBr0ej1Tp06tnr7YpEkTnn32Wbp06UJkZCStWrXCz8/vstuOiIjgtddeo2/fvkgpufHGGxkxYkSNY9q3b89tt91GSkoKMTExNYaL6hu3lhq+UlddargBsDvtnCw+ed4vP4sqiqqPM+vMNXrjZ36P9olGr234QyxXq9hqZ8rS/Xy+MZUQHy09Ou5gVeY3mHVmHmjeh7jSxTidpcTEPExszKOqxoyb1fdSw8XFxVgsFux2OyNHjuS+++5j5MiRng7LJepbqeFrzuXULgcIMgYR7x/PjXE3Vk9HjPeLJ8wc1iiHWGpjxZ5Mnp+3i1OF5QzpWMZx8TnLM1IZEN2TYd7ZyKJvsfh1JCnpn1i8EzwdrlIPTJ48mRUrVlBeXs7AgQNrfGl6rVFJ/irllef9Pre8apjlaP5R0kvSq4/RCA1Rlqjq+eVnpiTG+sbiZ7j8Hx+vVVmF5by4YA8Ld2bQPFzDsA4/sTpjAZHeEbzQehBBRQvRWo00T3qVyIhbqlZPUhSYMmWKp0OoN1SSvwiH00F6Sfo5j/AfLThKvjW/+jij1kicXxwpoSmM9BtZ46lPL6162OZKOZ2SGVtO8OqivVjtDkb1yGFbyTTWnsrjtmYD6aHZjqNwDqFhw0lIeE7NeVeUi1BJHii3l59TWOtIwRGOFR7D6rBWHxdgCCDOL44bYm6oMWYe7h2ORvUi
XeJQVjHPzt7J5tQ8OsSDX/QPLM9aT1JAIs80S8JYOBe9MZpkVY5AUS7LNZXk88vzz/nS80jBEdKL02vULo+0RBLvF0+3iG6VibyqjvmV1i5XLp/V7mDq6sO8v+owRi/BmL5HWZfzJcfznDzUYghtKlYhi3bStOlDxMU9gVZr8nTIitIgNLok75ROMkoyzpmSmFqYWqN2uZfGi1i/WJKDkxnebHiNIZba1i5XrswvqXlMnL2TQ1nF9Eu2U+TzDUtP7aZbeAfG+FegLZmFxactSUmf4+PjvgqDitIYNYokf7TgKFN/m1pdv7zcUV6970zt8j7RfWoU1nJH7XLlyhSU2XhjyT6+3nScSH8tI/v/xuqM7/Ep9eGZVoOIKl6MplxLs4TniYr6M0KoPy9FuVKNIslLJL9l/UacXxwdwzrWWCYuwBCgpiTWM1JKluw6xaT5u8kptjK8SwkHndNYkX6CYTG9GWA4iiyaQ1Bwf1okTsZojPR0yEo9Y7FYKC4uvvSBtfDBBx9gNpu5++673XaNF154gV69enHDDTfU2L569WqmTJnCDz/8UOtrNIokH+8Xz9IxSz0dhnIZ0vPLeGHeblbszSSpiYZOndawKmMx0T5RTG7dF//CxegdwbRo8x4hIYPUf9CKxzzyyCNuv8ZLL73k9ms0iiSv1H9nluH715J9OKVkdK9MNhdOY8OpIu5qPpDOzk04CxfSpMmfaBb/DHq9r6dDVi7TgQMvU1S816Vt+lhakpj4/CWPk1Ly97//ncWLFyOE4B//+Ae33XYbTqeTxx57jDVr1hAXF4fT6eS+++5jzJgx521nwoQJzJ8/H51Ox8CBA5kyZUqN1Z5++eUX7r//fry9venRoweLFy9m165dTJs2jblz5+JwONi1axd/+9vfqKio4Msvv8RgMLBo0SICAwPZvn07jzzyCKWlpTRr1oxPP/2UgIAA7r33XoYNG8aYMWNYsmQJTz31FMHBwbRv395ln6Wa96e43b5ThYyeuoFJ83fTJtZB526zWJb9H6ItEbzSsh0drXMx6r3p0H4GSS1eVgleuWyzZ89m+/bt/Pbbb6xYsYJnnnmGjIwMZs+eTWpqKjt37uTjjz9m48aNF2wjLy+POXPmsHv3bnbs2ME//vGPc44ZO3YsH3zwARs3bkSrrfnd0K5du/j666/ZvHkzzz33HGazmV9//ZVu3brxxRdfAHD33XfzxhtvsGPHDpKTk3nxxRdrtFFeXs6DDz7IggULWLdunUvr86uevOI25TYH76w8yIdrK5fhG9PvIGuypqPN1/KXpMG0KF+OLLYSF/cUMTEPodGoejMN0eX0uN1l/fr13HHHHWi1WsLCwujduze//PIL69ev55ZbbkGj0RAeHk7fvn0v2Iavry9Go5EHHniAoUOHMmzYsBr78/PzKSoq4vrrrwfgT3/6U42x8r59++Lj44OPjw9+fn7cdNNNQGXN+x07dlBQUEB+fj69e/cG4J577uGWW26pcY19+/YRFxdHQkJlWY677rqLDz/8sPYfECrJK27y06Ecnpuzk9TcUgakVJBrms7SjP30jOzMSN9CNCWz8fXvTFKLV/D2jvd0uEoDdaECi1dSeFGn07F582ZWrlzJt99+y7vvvsuPP/542W39sSb92fXq7Xb7hU47h7u+f1LDNYpLnS6p4G/f/cadH29CCis399vC5orJ5FtzmdBqAGN0G/CyHqJl0mu0bzddJXilVnr16sWMGTNwOBxkZ2ezdu1aOnfuTI8ePZg1axZOp5PMzExWr159wTaKi4spKCjgxhtv5K233mL79u019gcEBODj48PPP/8MwLfffntFMfr5+REQEMC6desA+PLLL6t79WckJSVx9OhRDh8+DFSuMesqqievuISUknnb03nphz0UltkY0a2APRXTWJmRwc1x/eir24ssmkdI6I0kJryAwRDi6ZCVRmDkyJFs3LiR6667DiEE//rXvwgPD2f06NGsXLmSNm3akJiY
SJcuXS5YU76oqIgRI0ZQXl6OlJL//Oc/5xzzySef8OCDD+Lt7U2fPn2uqD49wOeff179xWt8fDyfffZZjf1Go5EPP/yQoUOHEhwcTI8ePdi1a9cVXeNCVD15pdaO55by3NydrDuYQ5umGpo0W8aGUyuI94tjbFQ0PoVLMRoiaNHiJYKDLzw2qjQc9b2ePPxeUz43N5fOnTvz008/ER4eXqu2AF5//XUyMjJ4++23XRnuZVP15JU6Y3c4+WT9Uf6z4gBaAWN6n2RD/jTSs8q5N2EQ7R3rcRbuJTp6LPFxT6HTNaxFwpWGbdiwYeTn51NRUcHzzz9/1QkeYOHChbz22mvY7XZiYmJqLDdY36kkr1yVHWn5TJy9k93phXRv6YSgmSzN2ka7kGT+FKxDXzwHs6UVLa/7EF/ftp4OV7kGnW8cfuTIkRw9erTGtjfeeINBgwZdtK3bbruN2267zZXh1RmV5JUrUmK18+byA3z201GCLFpG99vDqsxvMBYZebLlYOJLlyJKHcQ3G0909H1oNOqvmFJ/zJkzx9Mh1Lla/QsUQtwCTAZaAp2llFvO2jcRuB9wAE9IKVXdgQZu1f4s/jFnFyfzyxjSoZw03RcsyzjCDVHXM9SciSieTUBAD5KSXsZkaurpcBVFofY9+V3AKOD/zt4ohGgF3A60BiKBFUKIRCmlo5bXUzwgu8jKSz/sYcFv6cSHaRnefyOr0+cT7h3GP1r1J6R4MTq7L4mt3iQsbLiqN6Mo9UitkryUci+cdxL/COBbKaUVOCqEOAR0Bi78bLFS70gp+X5LGq8s2ktZhYOR3fP4rewz1mTkMia+Hz00vyGLFhAePormzSfi5RXo6ZAVRfkDdz0M1QQ4cdb7tKpt5xBCPCSE2CKE2JKdne2mcJQrdSS7mDs++pm/z9pBfJidXj1/YEXevwgw+PFyy650ty3AqIF2KV/QqtW/VYJX6lR+fj7vv//+VZ/fp08frpXp2pfsyQshVgDnm3v0nJRy3oVOO8+2807Il1J+CHwIlfPkLxWP4l4Vdif/t+Yw/111CIMOxvQ5xvq8L0jLcfBg4mCSbauQxbtpGvMIcbGPo9WqVbSUuncmyT/66KOeDqXeu2SSl1LecKljziMNiD7rfRSQfhXtKHVo67E8JszaycGsYvokOyjz/ZalmTvpHJbCLQF29CWzK5fha/kaPpYkT4er1BPPH0xjV3GZS9tsYzHxckLUBfdPmDCBw4cPk5KSQt++fdmxYwenT5/GZrPxz3/+kxEjRpCamsqQIUPo0aMHGzZsoEmTJsybNw+TqXJ94O+//55HH32U/Px8PvnkE3r27OnSe6gv3DW/bT7wtRDiTSq/eE0ANrvpWkotFZbb+PeS/Xy16RjhvlpG9dvBqlPfYSmx8HTLgcSULEVTLtQyfEq98frrr7Nr1y62b9+O3W6ntLQUX19fcnJy6Nq1K8OHDwfg4MGDfPPNN3z00UfceuutzJo1i7vuugsAu93O5s2bWbRoES+++CIrVqzw5C25TW2nUI4E/guEAAuFENullIOklLuFEN8BewA78Bc1s6Z+qlyGbxdZRVaGdCwhVXzO8ozjDI7uyWDjMSieS1BQP1q0eFEtw6ec18V63HVBSsmzzz7L2rVr0Wg0nDx5kszMTADi4uJISUkBoEOHDqSmplafN2rUqPNub2xqO7tmDnDepwuklK8Ar9SmfcV9ThWU88K8XSzbk0lihIaOHdexNmMhUZYmTGrVl4CixXg5A0ls819CQ4aoaZFKvTV9+nSys7PZunUrer2e2NhYysvLgZplgLVaLWVlvw8rndmn1WqvqCRwQ6MeR7zGOJ2S6ZuO8caS/dgcDkb1zGJL0Wf8dKqQO5rdQDe24CxaSGTk7TRv9nf0+iurtqcodcHHx4eioiIACgoKCA0NRa/Xs2rVKo4dO+bh6OoXleSvIQcyi5gwawfbjufTuTmYI+exPHsjrQOTGN8sEWPRfIzmZiS1+ZYA/06e
DldRLigoKIju3bvTpk0bOnXqxL59++jYsSMpKSkkJalJAWdTpYavAeU2B++tOsQHaw7jbRD073KQtTlfIhDc06wXrawrkc5yYmMeITb2EbUMn3JJDaHUcGOlSg0rNWw8nMuzc3ZyNKeEG1Js5Ju/ZtmpvfSI6MQo3yI0pXPw8etIy6RX8PZu7ulwFUVxMZXkG6n80gpeW7SPGVtOEB2oZWT/X1mVMRN/4c/4VgOJLF6EtsJA8xb/JDLyNoRQK0EqSmOkknwjI6VkwY4MXlqwm9OlNoZ3LWKf/TNWpKczPLYv/fT7oWguISFDSEx8AYMh1NMhK4riRirJNyIn8kp5ft4uVu/Ppk20oFOnlaw6tYw43xheat0b38JFGHRhtEj+P0JCruZBZkVRGhqV5BsBu8PJtA2p/O+yAwghGd0rg00F0ziZVcLdzQfSUW7EWbiPqKi7aRb/NDqdxdMhK4pSR1SSb+B2nSxg4uyd7DxZwPVJEm3ILJZlbyEluA13hBgwFM/F7N2CpOT38fNL8XS4iqLUMZXkG6jSCjtvrTjIJ+uP4m/WMrrfPlZnfo2hwMDjSYNpXrYUUWonLv4Zmja9H41G7+mQFUXxADWlogFacyCbQW+t5cO1R7ghpYyo1h+wLGMa3SM68FJ8GM1KZuPvdx1dOi+qmveuErzSuHi6nnxqaipt2rS5rGM/+OADvvjii4seM23aNB577LGrjudiVE++AckptvLPH/Ywd3s6saEaRvTbxKqMuYSIEJ5tdQNhxYvRVniT2PJfhIePUvVmlDrx4oLd7EkvdGmbrSJ9mXRT6wvubyj15O12O4888ohHY1A9+Qagchm+E9zw5hoW7sxgxPWn0UZNYVXGXEbF9eO5SAgtmk9Y6FC6dV1GRMRoleCVRu3sevJ//etf6d+/P+3btyc5OZl58yrXMkpNTaVly5Y8+OCDtG7dmoEDB9YoUPb999/TuXNnEhMTWbdu3QWvtXv3bjp37kxKSgpt27bl4MGDADgcjvO23adPH5599ll69+7N22+/zeTJk5kyZUr1vvHjx1/0ugsXLqRbt27k5OS45LNSPfl6LjWnhGfn7GTD4VxSYgUhsYv5MXM1zf2b8WRcAt5FCzDookm6bhpBQY1z0QOlfrtYj9td6rKe/AcffMCTTz7JnXfeSUVFBQ6Hg8zMzIu2nZ+fz5o1awCYPHlyjfYudt05c+bw5ptvsmjRIgICAlzyWakkX0/ZHE4+XHuEd1YexEsLY/qcYEPeF6TlVHB/4mBS7GtwFu8huumDxMc9gVZr9nTIiuIR7q4n361bN1555RXS0tIYNWoUCQkJl2z7tttuu2B7F7ruqlWr2LJlC8uWLcPX1/cy7/7SVJKvh349fpqJs3ey71QRvVs7sPp/x9LM3+gUlsJtgRJd8WzMPm1omfQZPj5134tSlPrE3fXk//SnP9GlSxcWLlzIoEGD+Pjjj4mPj79o297e3hds70LXjY+P58iRIxw4cICOHc9ba+yqqCRfjxRb7UxZup/PN6YS6qNjVL9drMr8Fu9ib55qOYj40qWIMohP+AdRTf6MRqP++JRrU13Wkz9y5Ajx8fE88cQTHDlyhB07dhAfH+/SawDExMQwZcoURo4cyffff0/r1q7pwKksUU8s35PJC/N2caqwnBs7lpIqvmB5RiqDonsyxHQCiucQENSHFokvYTI18XS4iuJRdVlPfsaMGXz11Vfo9XrCw8N54YUXKCx07WyiM1q0aMH06dO55ZZbWLBgAc2aNat1m6qevIdlFZYzaf5uFu86RUK4lhat1rImYwGR3hE8ENOSoOLF6HT+JCY+T1joMDVrRqkXVD15z1H15BsIp1PyzS/HeX3xPqx2ByN75PBryWesP5XP7c0GcD3bcBb9QETErTRvPh693t/TISuK0gCpJO8Bh7KKmDh7J7+knqZjM/CNWsCKrJ9oGdiCcfEJmIrmYTTHkdR6OgEBXT0drqJcE5YuXcr48eNrbIuL
i2POnDkeisg1VJKvQ1a7g/dXHeb91YcwGzSM6XuEtTlfIvMkD7cYROuKVcjiXcTEPkZszKNotWoZPkWpK4MGDWLQoEGeDsPlVJKvI5uOVC7Ddzi7hP5t7RRavmbpqT10C+/AaP9SdCVz8PFtR1LSK1gsLTwdrqIojYRK8m5WUGrj9SV7+WbzCZoEaBnZfzurMr7Hr8yPZ1oOJKpkMZpyPc0TX6JJkzvUMnyKoriUSvJuIqVk0c5TTJq/m7wSKzd1KeKg83NWpKcxNKY3AwyHoHguwSEDSUychNEQ7umQFUVphFS30Q1O5pfxwOdb+MvX2wjxtzG4zypWF76CVggmt+rNAOcSDI5CkpPfp23yVJXgFcWN1q1bR+vWrUlJSWHjxo0sWrTokudcSSnh+k4leRdyOCWfrj/KgDfXsOFwDqN7naI4+FU2Zq7kruYDGReSj3/REpo0uZOuXZcSGtL4vuRRlPpm+vTpjBs3ju3bt7N///7LSvKNiRqucZE96YVMnL2D39IK6Joo0YfNYVn2ZpKDWnJXmDeGormYvBNIavtf/P06eDpcRXGdxRPg1E7XthmeDENev+DukpISbr31VtLS0nA4HDz//PMEBwczbtw47HY7nTp1YurUqXz55Zd89913LF26lGXLlvHTTz9RVlbG+vXrmThxInv37uXw4cOcPHmSEydO8Pe//50HH3ywxrWmTZvGli1bePfddwEYNmwY48aNo2fPntx///1s2bIFIQT33Xcff/3rX137ObiASvK1VG5z8NaKg3y07gh+Jg1j+h1kTdZ0tAVa/tJiEC2sy6Gkgtj4p4lp+iAajZenQ1aUBm/JkiVERkaycOFCoLJ+TZs2bVi5ciWJiYncfffdTJ06laeeeor169czbNgwxowZc07Cnjx5Mjt27ODnn3+mpKSEdu3aMXTo0MuKYfv27Zw8eZJdu3YBleWF6yOV5Gth/cEcnpu7k2O5pQxMsZJtnM7SjAP0juzCzT6nEaVz8PXvQsukVzCb4zwdrqK4x0V63O6SnJzMuHHjGD9+PMOGDcPX15e4uDgSExMBuOeee3jvvfd46qmnLtnWiBEjMJlMmEwm+vbty+bNm6tLCF/MmaqRjz/+OEOHDmXgwIG1vCv3qNWYvBDi30KIfUKIHUKIOUII/7P2TRRCHBJC7BdCNKrB57ySCp7+bjt3fbIJISoY0f8XNlW8SEHFaSa0uoFRup/QVxylZdIbtG83XSV4RXGxxMREtm7dSnJyMhMnTqxeDepq/LEe1B/f63Q6nE5n9fszZYwDAgL47bff6NOnD++99x4PPPDAVcfgTrX94nU50EZK2RY4AEwEEEK0Am4HWgODgfeFENpaXsvjpJTM+TWNG95cw/zt6dx8fQG6mCmsSp/N8JjePNdES3jRfEJDB9Gt6zIiI8eogmKK4gbp6emYzWbuuusuxo0bx4YNG0hNTeXQoUMAfPnll/Tu3fuc884uUXzGvHnzKC8vJzc3l9WrV9OpU6ca+2NjY9m+fTtOp5MTJ06wefNmAHJycnA6nYwePZqXX36Zbdu2uelua6dWwzVSymVnvf0ZGFP1egTwrZTSChwVQhwCOgMba3M9TzqeW8pzc3ey7mAOyTGCTvHLWHlqJc384ni8VQ8sRQsx6pvQ4rpPCA7q4+lwFaVR27lzJ8888wwajQa9Xs/UqVMpKCjglltuqf7i9XwLaPft25fXX3+dlJQUJk6cCEDnzp0ZOnQox48f5/nnnycyMrLGik3du3cnLi6O5ORk2rRpQ/v27QE4efIkY8eOre7lv/baa+6/8avgslLDQogFwAwp5VdCiHeBn6WUX1Xt+wRYLKWceZ7zHgIeAmjatGkHVxf8ry27w8kn64/ynxUH0GlgcLfjbDj9OVa7lbua9SPFsQ6nLZfo6HuJj3sKne7CK8IoSmPRWEoNT548GYvFwrhx4zwdymVzealhIcQK4HxP6zwnpZxXdcxzgB2Yfua08xx/3v9NpJQf
Ah9CZT35S8VTl3ak5TNh1k72ZBTSs6UTR9D3LM38lQ6hbbk9SIu+eDbeltYkXfcRvr7Jng5XURTlHJdM8lLKGy62XwhxDzAM6C9//7EgDYg+67AoIP1qg6xrJVY7/7vsANM2HCXYR8vo/ntZdeprTEUmnmo5mPjSJVDqJL75BKKjxqpl+BSlgZo8ebKnQ3C7WmUnIcRgYDzQW0pZetau+cDXQog3gUggAdhcm2vVlVX7svjH3F2kF5QxpEM5J7RfsCz9CAOie3CjOR1RPBv/wJ4ktXgZkyn60g0qiqJ4UG27oO8CBmB51SySn6WUj0gpdwshvgP2UDmM8xcppaOW13KrrKJyXlqwhx92ZNAsTMNN/TawKn0+Ed7h/KNVf0KKF6Oz+ZLY6j+Ehd2kZs0oitIg1HZ2TfOL7HsFeKU27dcFKSXfbTnBKwv3Um5zMLJ7LtvLprEmI5db4vvTQ/MbsmgB4eGjSUiYiF4f4OmQFUVRLts1PZh8OLuYZ2fvZNPRPNrFQWDThazIWkeLgASejkvAXLQAoymGpJQvCQy83tPhKoqiXLFrMslX2J18sOYw7/54CIMexvRNZV3OF6TlOnkocTBtbKuQxbuJifkfYmMfQ6s1ejpkRVGUq3LNlRrekprH0HfW8ebyA1zfqoKWHb5k6akPSA5K4pWE5rQqm43F3JROnebRrNk4leAVpYG7mnryf7R69WqGDRt2Wce+8MILrFix4qLHTJ48mSlTplxxHFfjmunJF5bbeGPxPqZvOk6kv5ZR/XewKuM7fEp8+FvLgTQtWYKmXEuzxElENbmTRlCFQVHqxBub32Bf3j6XtpkUmMT4zuNd0taZevJjx46trkJ54403uqTtP3I4HLz00ktuaftqNfqevJSSJbsyuOF/1/DN5uMM61KCX8I7LE//mgHR1zO5qQ/RxXMJCupB1y5LiI66WyV4RannSkpKGDp0KNdddx1t2rRhxowZrFy5knbt2pGcnMx9992H1Wrl448/5rvvvuOll17ijjvu4IUXXmDGjBmkpKQwY8aM87a9Zs0aUlJSSElJoV27dtW1boqLixkzZgxJSUnceeednHksKDY2lpdeeokePXrw/fffc++99zJz5szqfZMmTaJ9+/YkJyezb9+5/xl+9NFHDBkyhLKyMrd8Vo26J59RUMYL83azfE8mLSI1dO60ljUZi4j2ieLF1n3xK1yM3hlMizbvERIySE2LVJSr4Koe95VwVT3585kyZQrvvfce3bt3p7i4GKOxcsj2119/Zffu3URGRtK9e3d++uknevToAYDRaGT9+vXVsZ0tODiYbdu28f777zNlyhQ+/vjj6n3vvvsuy5YtY+7cuRgMBpd+Rmc0yp68wyn5YmMqA95cy7qDWYzqmUlZ6Gv8dGoZf2o+gGdCi/ErXEiTyNvo2mUpoaGDVYJXlAYkOTmZFStWMH78eNatW0dqauo59eTXrl17VW13796dp59+mnfeeYf8/Hx0usq+cOfOnYmKikKj0ZCSklKjiNltt912wfZGjRoFQIcOHWqc8+WXX7J48WJmzZrltgQPjbAnv+9UIRNn7+TX4/l0TpCYI+azPHsjrYOS+HNoAsbieRjNzWnZZgb+/uet56MoSj13pp78okWLmDhxoksX7JgwYQJDhw5l0aJFdO3atfpL1LMTsVarxW63V7/39r5wYcIz5/3xnDZt2rB9+3bS0tKIi3PfmhONJsmX2xz898eD/N+aI/iYNIzpd4g12V+hOa3h0RaDSbKuQJaUExv3FLExD6HRuO9/TkVR3Cs9PZ3AwEDuuusuLBYLH3zwQXU9+ebNm19RPfk/Onz4MMnJySQnJ7Nx40b27duHv7+/y++hXbt2/M///A/Dhw9n6dKlREZGuvwa0EiS/J70Qh6dvpXU3FJuSKkgzzSdpRn76RnRmZt9C9CWzsbXrxNJSa/g7d3M0+EqilJLrqonf75hlrfeeotVq1ah1Wpp1aoVQ4YMYeNG9yyF
0aNHD6ZMmcLQoUNZvnw5wcHBLr+Gy+rJu0LHjh3lli1brvi87CIr932+jmaJG1h1ahaBxgDuj21PZMlitFoDzZtNIDLyVoRolF9BKEqdayz15Bsil9eTbwiyKw5TEf5vVmakMyK2L/30+5DF8wgJvZHEhBcwGEI8HaKiKIpHNIokH+EdQZg5hEdimuNbtAgvXTgt2n5ISHB/T4emKEo99dlnn/H222/X2Na9e3fee+89D0XkHo0iyWsrTvCQ7xGsRVlER91DfPxf0eksng5LURo1KWWDnno8duxYxo4d6+kwrsjVDK83iiRvMjXF2zuB5LZT8fO9ztPhKEqjZzQayc3NJSgoqEEn+oZESklubm71w1mXq1Ekeb3en3btPvd0GIpyzYiKiiItLY3s7GxPh3JNMRqNREVFXdE5jSLJK4pSt/R6vVsf4FFcR80pVBRFacRUklcURWnEVJJXFEVpxFSSVxRFacRUklcURWnEVJJXFEVpxFSSVxRFacTqVRVKIUQ2cMzTcVylYCDH00F4gLrva4u67/opRkp53kqM9SrJN2RCiC0XKvXZmKn7vrao+2541HCNoihKI6aSvKIoSiOmkrzrfOjpADxE3fe1Rd13A6PG5BVFURox1ZNXFEVpxFSSVxRFacRUkq8FIcS/hRD7hBA7hBBzhBD+Z+2bKIQ4JITYL4QY5MEwXU4IcYsQYrcQwimE6PiHfY32vgGEEIOr7u2QEGKCp+NxJyHEp0KILCHErrO2BQohlgshDlb9HuDJGF1NCBEthFglhNhb9Xf8yartDfa+VZKvneVAGyllW+AAMBFACNEKuB1oDQwG3hdCaD0WpevtAkYBa8/e2Njvu+pe3gOGAK2AO6ruubGaRuWf49kmACullAnAyqr3jYkd+JuUsiXQFfhL1Z9xg71vleRrQUq5TEppr3r7M3BmXa4RwLdSSquU8ihwCOjsiRjdQUq5V0q5/zy7GvV9U3kvh6SUR6SUFcC3VN5zoySlXAvk/WHzCODMWpufAzfXZUzuJqXMkFJuq3pdBOwFmtCA71slede5D1hc9boJcOKsfWlV2xq7xn7fjf3+LkeYlDIDKhMiEOrheNxGCBELtAM20YDvW63xeglCiBVA+Hl2PSelnFd1zHNU/pg3/cxp5zm+Qc1VvZz7Pt9p59nWoO77Ehr7/SlVhBAWYBbwlJSyUIjz/dE3DCrJX4KU8oaL7RdC3AMMA/rL3x86SAOizzosCkh3T4Tucan7voAGf9+X0Njv73JkCiEipJQZQogIIMvTAbmaEEJPZYKfLqWcXbW5wd63Gq6pBSHEYGA8MFxKWXrWrvnA7UIIgxAiDkgANnsixjrW2O/7FyBBCBEnhPCi8kvm+R6Oqa7NB+6pen0PcKGf6hokUdll/wTYK6V886xdDfa+1ROvtSCEOAQYgNyqTT9LKR+p2vccleP0dip/5Ft8/lYaHiHESOC/QAiQD2yXUg6q2tdo7xtACHEj8BagBT6VUr7i2YjcRwjxDdCHyjK7mcAkYC7wHdAUOA7cIqX845ezDZYQogewDtgJOKs2P0vluHyDvG+V5BVFURoxNVyjKIrSiKkkryiK0oipJK8oitKIqSSvKIrSiKkkryiK0oipJK8oitKIqSSvKIrSiP0/kgWIpFQttyQAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "def elu(x, alpha=1):\n",
    "    # ELU: identity for x > 0, alpha*(exp(x) - 1) for x <= 0\n",
    "    return max(0, x) + min(0, alpha*(np.exp(x) - 1))\n",
    "def leaky_relu(x, negative_slope=1e-2):\n",
    "    # LeakyReLU: identity for x > 0, small linear slope for x <= 0\n",
    "    return max(0, x) + negative_slope*min(0, x)\n",
    "def p_relu(x, a=0.25):\n",
    "    # PReLU: same form as LeakyReLU; in PyTorch the slope `a` is learnable\n",
    "    return max(0, x) + a*min(0, x)\n",
    "def relu(x):\n",
    "    return max(0, x)\n",
    "def relu_6(x):\n",
    "    # ReLU6: ReLU clipped at 6\n",
    "    return min(max(0, x), 6)\n",
    "def selu(x):\n",
    "    # SELU uses TWO distinct constants: scale (~1.0507) and alpha (~1.6733).\n",
    "    # Fixed: the original reused the scale constant in place of alpha.\n",
    "    scale = 1.0507009873554804934193349852946\n",
    "    alpha = 1.6732632423543772848170429916717\n",
    "    return scale*(max(0, x) + min(0, alpha*(np.exp(x) - 1)))\n",
    "def celu(x, alpha=0.25):\n",
    "    # CELU: max(0, x) + min(0, alpha*(exp(x/alpha) - 1)).\n",
    "    # Fixed: the `- 1` belongs inside the alpha*(...) term.\n",
    "    return max(0, x) + min(0, alpha*(np.exp(x/alpha) - 1))\n",
    "def sigmoid(x):\n",
    "    return 1/(1+np.exp(-x))\n",
    "def log_sigmoid(x):\n",
    "    return np.log(1/(1+np.exp(-x)))\n",
    "def tanh(x):\n",
    "    return (np.e**x - np.e**-x)/(np.e**x + np.e**-x)\n",
    "def tanh_shrink(x):\n",
    "    return x - (np.e**x - np.e**-x)/(np.e**x + np.e**-x)\n",
    "def softplus(x, beta=1):\n",
    "    return (1/beta)*np.log(1+np.exp(beta*x))\n",
    "def soft_shrink(x, lambd=0.5):\n",
    "    # SoftShrink: shrink toward zero by lambd; zero inside [-lambd, lambd]\n",
    "    if x > lambd:\n",
    "        return x - lambd\n",
    "    elif x < -lambd:\n",
    "        return x + lambd\n",
    "    else:\n",
    "        return 0\n",
    "\n",
    "x = [i-25 for i in range(50)]\n",
    "# plot every activation curve on one figure\n",
    "plt.plot(x, [elu(i) for i in x], label='elu')\n",
    "plt.plot(x, [leaky_relu(i) for i in x], label='leaky_relu')\n",
    "plt.plot(x, [p_relu(i) for i in x], label='p_relu')\n",
    "plt.plot(x, [relu(i) for i in x], label='relu')\n",
    "plt.plot(x, [relu_6(i) for i in x], label='relu_6')\n",
    "plt.plot(x, [selu(i) for i in x], label='selu')\n",
    "plt.plot(x, [celu(i) for i in x], label='celu')\n",
    "plt.plot(x, [sigmoid(i) for i in x], label='sigmoid')\n",
    "plt.plot(x, [log_sigmoid(i) for i in x], label='log_sigmoid')\n",
    "plt.plot(x, [tanh(i) for i in x], label='tanh')\n",
    "plt.plot(x, [tanh_shrink(i) for i in x], label='tanh_shrink')\n",
    "plt.plot(x, [softplus(i) for i in x], label='softplus')\n",
    "plt.plot(x, [soft_shrink(i) for i in x], label='soft_shrink')\n",
    "\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0bb148a8",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "# PyTorch卷积层原理和使用"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74d4d115",
   "metadata": {
    "hidden": true
   },
   "source": [
    "$$params = OutChannels * (InChannels*KernelSize_h*KernelSize_w) + OutChannels$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9ccf382f",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "            Conv2d-1           [-1, 32, 32, 32]             832\n",
      "================================================================\n",
      "Total params: 832\n",
      "Trainable params: 832\n",
      "Non-trainable params: 0\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 0.00\n",
      "Forward/backward pass size (MB): 0.25\n",
      "Params size (MB): 0.00\n",
      "Estimated Total Size (MB): 0.26\n",
      "----------------------------------------------------------------\n",
      "None\n",
      "参数量公式：out_channels*(in_channels*kernel_size_h*kernel_size_w)+out_channels=32*(1*5*5)+32=832\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "class MyConvNet(torch.nn.Module):\n",
    "    \"\"\"Single 5x5 conv: 1 input channel -> 32 output channels, padding=2 keeps H/W.\"\"\"\n",
    "    def __init__(self):\n",
    "        super(MyConvNet, self).__init__()\n",
    "        self.conv = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, stride=1, padding=2)\n",
    "\n",
    "    def forward(self, X):\n",
    "        # forward pass: just the convolution\n",
    "        return self.conv(X)\n",
    "\n",
    "# fall back to CPU when CUDA is unavailable instead of crashing on .cuda()\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "model = MyConvNet().to(device)\n",
    "\n",
    "# torchsummary defaults to device='cuda'; keep it in sync with the model\n",
    "print(summary(model, input_size=(1, 32, 32), batch_size=-1, device=device))\n",
    "print('参数量公式：out_channels*(in_channels*kernel_size_h*kernel_size_w)+out_channels={}*({}*{}*{})+{}={}'.format(32,1,5,5,32,32*(1*5*5)+32))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f33c58f",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "# PyTorch常见的损失函数和优化器使用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8c2bf02e",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "最后的loss：nan\n",
      "最后的loss：nan\n",
      "最后的loss：0.0003276001079939306\n"
     ]
    }
   ],
   "source": [
    "# generate synthetic linear data: y = 10x + 4\n",
    "def gen_data():\n",
    "    def f(x):\n",
    "        return 10 * x + 4\n",
    "\n",
    "    # shift the whole grid by one random offset; targets stay exactly on the line\n",
    "    x = np.linspace(0, 10, 1000).astype(np.float32) + np.random.uniform(0, 1, size=1)[0]\n",
    "    y = f(x).astype(np.float32)\n",
    "\n",
    "    x = x.reshape(1000, 1)\n",
    "    y = y.reshape(1000, 1)\n",
    "    return x, y\n",
    "\n",
    "# convert numpy arrays to tensors; also create manual parameters w/b\n",
    "def trans_tensor(x, y):\n",
    "    x_train = torch.from_numpy(x)\n",
    "    y_train = torch.from_numpy(y)\n",
    "\n",
    "    # Variable is deprecated since PyTorch 0.4; plain tensors with\n",
    "    # requires_grad=True are equivalent. w/b are kept for interface\n",
    "    # compatibility, but nn.Linear below supplies its own parameters.\n",
    "    w = torch.randn(1, requires_grad=True)\n",
    "    b = torch.zeros(1, requires_grad=True)\n",
    "    return x_train, y_train, w, b\n",
    "\n",
    "\n",
    "class NNLinearNet(nn.Module):\n",
    "    \"\"\"Single-layer linear regression model.\"\"\"\n",
    "    def __init__(self, n_feature):\n",
    "        super(NNLinearNet, self).__init__()\n",
    "        self.linear = nn.Linear(n_feature, 1)\n",
    "\n",
    "    def forward(self, X):\n",
    "        return self.linear(X)\n",
    "\n",
    "\n",
    "def train():\n",
    "    \"\"\"Train with several learning rates; too-large rates diverge (loss -> nan).\"\"\"\n",
    "    x, y = gen_data()\n",
    "    x_train, y_train, w, b = trans_tensor(x, y)\n",
    "\n",
    "    # wrap features/labels and read shuffled mini-batches of 10\n",
    "    dataset = Data.TensorDataset(x_train, y_train)\n",
    "    data_iter = Data.DataLoader(dataset, 10, shuffle=True)\n",
    "\n",
    "    learning_rate_list = [0.1, 0.5, 0.01]\n",
    "    for learning_rate in learning_rate_list:\n",
    "        epochs = 10\n",
    "        model = NNLinearNet(len(x[0]))\n",
    "        loss = nn.MSELoss()\n",
    "        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)\n",
    "        for epoch in range(epochs):\n",
    "            # y_batch avoids shadowing the outer `y` array\n",
    "            for X, y_batch in data_iter:\n",
    "                output = model(X)\n",
    "                l = loss(output, y_batch.view(-1, 1))\n",
    "                optimizer.zero_grad()\n",
    "                l.backward()\n",
    "                optimizer.step()\n",
    "        print('最后的loss：{}'.format(l.item()))\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb6c0e3b",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "# PyTorch池化层和归一化层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60e15c25",
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "'''\n",
    "    N: batch, C: channel, H: height, W: width\n",
    "    Pooling maps (N, C, H_in, W_in) -> (N, C, H_out, W_out)\n",
    "'''\n",
    "max_pool = nn.MaxPool2d(2)\n",
    "mean_pool = nn.AvgPool2d(3)\n",
    "\n",
    "# named `inputs` to avoid shadowing the Python builtin `input`\n",
    "inputs = torch.randn(20, 16, 30, 24)\n",
    "max_output = max_pool(inputs)\n",
    "mean_output = mean_pool(inputs)\n",
    "print('输入数据的维度: {}'.format(inputs.shape))\n",
    "print('最大池化的维度: {}'.format(max_output.shape))\n",
    "print('平均池化的维度: {}'.format(mean_output.shape))\n",
    "\n",
    "\n",
    "'''\n",
    "    BatchNorm2d normalizes each of the C=16 channels over (N, H, W)\n",
    "'''\n",
    "batch_norm = nn.BatchNorm2d(16)\n",
    "print('batch方向做归一化: {}'.format(batch_norm(inputs).shape))\n",
    "\n",
    "'''\n",
    "    LayerNorm normalizes each sample over its trailing (C, H, W) dims\n",
    "'''\n",
    "layer_norm = nn.LayerNorm([16, 30, 24])\n",
    "print('channel方向做归一化: {}'.format(layer_norm(inputs).shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1516e9ca",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "# 使用PyTorch搭建VGG网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c4885f16",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "            Conv2d-1         [-1, 64, 224, 224]           1,728\n",
      "       BatchNorm2d-2         [-1, 64, 224, 224]             128\n",
      "              ReLU-3         [-1, 64, 224, 224]               0\n",
      "            Conv2d-4        [-1, 128, 112, 112]          73,728\n",
      "       BatchNorm2d-5        [-1, 128, 112, 112]             256\n",
      "              ReLU-6        [-1, 128, 112, 112]               0\n",
      "            Conv2d-7          [-1, 256, 56, 56]         294,912\n",
      "       BatchNorm2d-8          [-1, 256, 56, 56]             512\n",
      "              ReLU-9          [-1, 256, 56, 56]               0\n",
      "           Conv2d-10          [-1, 256, 56, 56]         589,824\n",
      "      BatchNorm2d-11          [-1, 256, 56, 56]             512\n",
      "             ReLU-12          [-1, 256, 56, 56]               0\n",
      "           Conv2d-13          [-1, 512, 28, 28]       1,179,648\n",
      "      BatchNorm2d-14          [-1, 512, 28, 28]           1,024\n",
      "             ReLU-15          [-1, 512, 28, 28]               0\n",
      "           Conv2d-16          [-1, 512, 28, 28]       2,359,296\n",
      "      BatchNorm2d-17          [-1, 512, 28, 28]           1,024\n",
      "             ReLU-18          [-1, 512, 28, 28]               0\n",
      "           Conv2d-19          [-1, 512, 14, 14]       2,359,296\n",
      "      BatchNorm2d-20          [-1, 512, 14, 14]           1,024\n",
      "             ReLU-21          [-1, 512, 14, 14]               0\n",
      "           Conv2d-22          [-1, 512, 14, 14]       2,359,296\n",
      "      BatchNorm2d-23          [-1, 512, 14, 14]           1,024\n",
      "             ReLU-24          [-1, 512, 14, 14]               0\n",
      "           Linear-25                 [-1, 4096]     102,764,544\n",
      "      BatchNorm1d-26                 [-1, 4096]           8,192\n",
      "           Linear-27                 [-1, 4096]      16,781,312\n",
      "      BatchNorm1d-28                 [-1, 4096]           8,192\n",
      "           Linear-29                 [-1, 1000]       4,097,000\n",
      "================================================================\n",
      "Total params: 132,882,472\n",
      "Trainable params: 132,882,472\n",
      "Non-trainable params: 0\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 0.57\n",
      "Forward/backward pass size (MB): 170.10\n",
      "Params size (MB): 506.91\n",
      "Estimated Total Size (MB): 677.58\n",
      "----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "#!/usr/bin/env python\n",
    "# -*- coding: utf-8 -*-\n",
    "# @File  : 8.使用PyTorch搭建VGG网络.py\n",
    "# @Author: Richard Chiming Xu\n",
    "# @Date  : 2021/11/15\n",
    "# @Desc  :\n",
    "\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchsummary import summary\n",
    "\n",
    "class VGG(nn.Module):\n",
    "    \"\"\"Configurable VGG: `arch` lists the number of 3x3 convs in each of 5 stages.\"\"\"\n",
    "    def __init__(self, arch: object, num_classes=1000) -> object:\n",
    "        super(VGG, self).__init__()\n",
    "        self.in_channels = 3\n",
    "        self.conv3_64 = self.__make_layer(64, arch[0])\n",
    "        self.conv3_128 = self.__make_layer(128, arch[1])\n",
    "        self.conv3_256 = self.__make_layer(256, arch[2])\n",
    "        self.conv3_512a = self.__make_layer(512, arch[3])\n",
    "        self.conv3_512b = self.__make_layer(512, arch[4])\n",
    "        self.fc1 = nn.Linear(7*7*512, 4096)  # 224 / 2^5 = 7 after five poolings\n",
    "        self.bn1 = nn.BatchNorm1d(4096)\n",
    "        self.bn2 = nn.BatchNorm1d(4096)\n",
    "        self.fc2 = nn.Linear(4096, 4096)\n",
    "        self.fc3 = nn.Linear(4096, num_classes)\n",
    "\n",
    "    def __make_layer(self, channels, num):\n",
    "        # build `num` conv-BN-ReLU units; padding=1 gives 'same' spatial size\n",
    "        layers = []\n",
    "        for i in range(num):\n",
    "            layers.append(nn.Conv2d(self.in_channels, channels, 3, stride=1, padding=1, bias=False))\n",
    "            layers.append(nn.BatchNorm2d(channels))\n",
    "            layers.append(nn.ReLU())\n",
    "            # this stage's out_channels become the next conv's in_channels\n",
    "            self.in_channels = channels\n",
    "        return nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.conv3_64(x)\n",
    "        out = F.max_pool2d(out, 2)\n",
    "        out = self.conv3_128(out)\n",
    "        out = F.max_pool2d(out, 2)\n",
    "        out = self.conv3_256(out)\n",
    "        out = F.max_pool2d(out, 2)\n",
    "        out = self.conv3_512a(out)\n",
    "        out = F.max_pool2d(out, 2)\n",
    "        out = self.conv3_512b(out)\n",
    "        out = F.max_pool2d(out, 2)\n",
    "        out = out.view(out.size(0), -1)\n",
    "        out = self.fc1(out)\n",
    "        out = self.bn1(out)\n",
    "        out = F.relu(out)\n",
    "        out = self.fc2(out)\n",
    "        out = self.bn2(out)\n",
    "        out = F.relu(out)\n",
    "        # explicit dim: softmax over the class dimension (implicit dim is deprecated)\n",
    "        return F.softmax(self.fc3(out), dim=1)\n",
    "\n",
    "\n",
    "# each variant: sum(arch) conv layers + 3 FC layers = depth\n",
    "def VGG_11():\n",
    "    return VGG([1, 1, 2, 2, 2], num_classes=1000)\n",
    "\n",
    "def VGG_13():\n",
    "    # fixed: VGG-13 is 2 convs per stage (10 convs + 3 FC), not the VGG-11 arch\n",
    "    return VGG([2, 2, 2, 2, 2], num_classes=1000)\n",
    "\n",
    "def VGG_16():\n",
    "    return VGG([2, 2, 3, 3, 3], num_classes=1000)\n",
    "\n",
    "def VGG_19():\n",
    "    return VGG([2, 2, 4, 4, 4], num_classes=1000)\n",
    "\n",
    "net = VGG_11()\n",
    "# net = VGG_13()\n",
    "# net = VGG_16()\n",
    "# net = VGG_19()\n",
    "# fall back to CPU when CUDA is unavailable instead of crashing on .cuda()\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "net = net.to(device)\n",
    "summary(net, (3, 224, 224), device=device)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5cfa41c0",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "# 使用PyTorch搭建ResNet网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f07348cc",
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "            Conv2d-1         [-1, 64, 112, 112]           9,408\n",
      "       BatchNorm2d-2         [-1, 64, 112, 112]             128\n",
      "         MaxPool2d-3           [-1, 64, 56, 56]               0\n",
      "            Conv2d-4           [-1, 64, 56, 56]          36,864\n",
      "       BatchNorm2d-5           [-1, 64, 56, 56]             128\n",
      "              ReLU-6           [-1, 64, 56, 56]               0\n",
      "         BN_Conv2d-7           [-1, 64, 56, 56]               0\n",
      "            Conv2d-8           [-1, 64, 56, 56]          36,864\n",
      "       BatchNorm2d-9           [-1, 64, 56, 56]             128\n",
      "        BN_Conv2d-10           [-1, 64, 56, 56]               0\n",
      "       BasicBlock-11           [-1, 64, 56, 56]               0\n",
      "           Conv2d-12           [-1, 64, 56, 56]          36,864\n",
      "      BatchNorm2d-13           [-1, 64, 56, 56]             128\n",
      "             ReLU-14           [-1, 64, 56, 56]               0\n",
      "        BN_Conv2d-15           [-1, 64, 56, 56]               0\n",
      "           Conv2d-16           [-1, 64, 56, 56]          36,864\n",
      "      BatchNorm2d-17           [-1, 64, 56, 56]             128\n",
      "        BN_Conv2d-18           [-1, 64, 56, 56]               0\n",
      "       BasicBlock-19           [-1, 64, 56, 56]               0\n",
      "           Conv2d-20          [-1, 128, 28, 28]          73,728\n",
      "      BatchNorm2d-21          [-1, 128, 28, 28]             256\n",
      "             ReLU-22          [-1, 128, 28, 28]               0\n",
      "        BN_Conv2d-23          [-1, 128, 28, 28]               0\n",
      "           Conv2d-24          [-1, 128, 28, 28]         147,456\n",
      "      BatchNorm2d-25          [-1, 128, 28, 28]             256\n",
      "        BN_Conv2d-26          [-1, 128, 28, 28]               0\n",
      "           Conv2d-27          [-1, 128, 28, 28]           8,192\n",
      "      BatchNorm2d-28          [-1, 128, 28, 28]             256\n",
      "       BasicBlock-29          [-1, 128, 28, 28]               0\n",
      "           Conv2d-30          [-1, 128, 28, 28]         147,456\n",
      "      BatchNorm2d-31          [-1, 128, 28, 28]             256\n",
      "             ReLU-32          [-1, 128, 28, 28]               0\n",
      "        BN_Conv2d-33          [-1, 128, 28, 28]               0\n",
      "           Conv2d-34          [-1, 128, 28, 28]         147,456\n",
      "      BatchNorm2d-35          [-1, 128, 28, 28]             256\n",
      "        BN_Conv2d-36          [-1, 128, 28, 28]               0\n",
      "       BasicBlock-37          [-1, 128, 28, 28]               0\n",
      "           Conv2d-38          [-1, 256, 14, 14]         294,912\n",
      "      BatchNorm2d-39          [-1, 256, 14, 14]             512\n",
      "             ReLU-40          [-1, 256, 14, 14]               0\n",
      "        BN_Conv2d-41          [-1, 256, 14, 14]               0\n",
      "           Conv2d-42          [-1, 256, 14, 14]         589,824\n",
      "      BatchNorm2d-43          [-1, 256, 14, 14]             512\n",
      "        BN_Conv2d-44          [-1, 256, 14, 14]               0\n",
      "           Conv2d-45          [-1, 256, 14, 14]          32,768\n",
      "      BatchNorm2d-46          [-1, 256, 14, 14]             512\n",
      "       BasicBlock-47          [-1, 256, 14, 14]               0\n",
      "           Conv2d-48          [-1, 256, 14, 14]         589,824\n",
      "      BatchNorm2d-49          [-1, 256, 14, 14]             512\n",
      "             ReLU-50          [-1, 256, 14, 14]               0\n",
      "        BN_Conv2d-51          [-1, 256, 14, 14]               0\n",
      "           Conv2d-52          [-1, 256, 14, 14]         589,824\n",
      "      BatchNorm2d-53          [-1, 256, 14, 14]             512\n",
      "        BN_Conv2d-54          [-1, 256, 14, 14]               0\n",
      "       BasicBlock-55          [-1, 256, 14, 14]               0\n",
      "           Conv2d-56            [-1, 512, 7, 7]       1,179,648\n",
      "      BatchNorm2d-57            [-1, 512, 7, 7]           1,024\n",
      "             ReLU-58            [-1, 512, 7, 7]               0\n",
      "        BN_Conv2d-59            [-1, 512, 7, 7]               0\n",
      "           Conv2d-60            [-1, 512, 7, 7]       2,359,296\n",
      "      BatchNorm2d-61            [-1, 512, 7, 7]           1,024\n",
      "        BN_Conv2d-62            [-1, 512, 7, 7]               0\n",
      "           Conv2d-63            [-1, 512, 7, 7]         131,072\n",
      "      BatchNorm2d-64            [-1, 512, 7, 7]           1,024\n",
      "       BasicBlock-65            [-1, 512, 7, 7]               0\n",
      "           Conv2d-66            [-1, 512, 7, 7]       2,359,296\n",
      "      BatchNorm2d-67            [-1, 512, 7, 7]           1,024\n",
      "             ReLU-68            [-1, 512, 7, 7]               0\n",
      "        BN_Conv2d-69            [-1, 512, 7, 7]               0\n",
      "           Conv2d-70            [-1, 512, 7, 7]       2,359,296\n",
      "      BatchNorm2d-71            [-1, 512, 7, 7]           1,024\n",
      "        BN_Conv2d-72            [-1, 512, 7, 7]               0\n",
      "       BasicBlock-73            [-1, 512, 7, 7]               0\n",
      "        AvgPool2d-74            [-1, 512, 1, 1]               0\n",
      "           Linear-75                 [-1, 1000]         513,000\n",
      "================================================================\n",
      "Total params: 11,689,512\n",
      "Trainable params: 11,689,512\n",
      "Non-trainable params: 0\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 0.57\n",
      "Forward/backward pass size (MB): 62.41\n",
      "Params size (MB): 44.59\n",
      "Estimated Total Size (MB): 107.58\n",
      "----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "#!/usr/bin/env python\n",
    "# -*- coding: utf-8 -*-\n",
    "# @File  : 9.使用PyTorch搭建ResNet网络.py\n",
    "# @Author: Richard Chiming Xu\n",
    "# @Date  : 2021/11/15\n",
    "# @Desc  : Build ResNet-18/34/50/101/152 with PyTorch and print a summary.\n",
    "# FIX: removed `from telnetlib import SE` -- telnetlib's SE is a telnet\n",
    "# protocol byte constant (b'\\xf0'), not the squeeze-excitation module the\n",
    "# blocks below expect; define or import a real SE block before is_se=True.\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchsummary import summary\n",
    "\n",
    "# Packs Conv2d + BatchNorm2d (+ optional ReLU) into a single module.\n",
    "class BN_Conv2d(nn.Module):\n",
    "    \"\"\"Conv2d followed by BatchNorm2d, with an optional trailing ReLU.\n",
    "\n",
    "    Parameters mirror ``nn.Conv2d``; ``activation`` controls whether an\n",
    "    in-place ReLU is appended after the batch norm.\n",
    "    \"\"\"\n",
    "    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int, padding: int,\n",
    "                 dilation: int = 1, groups: int = 1, bias: bool = False, activation: bool = True) -> None:\n",
    "        super(BN_Conv2d, self).__init__()\n",
    "\n",
    "        # Conv -> BN always; ReLU only when `activation` is True (residual\n",
    "        # blocks skip it on their last conv so ReLU comes after the addition).\n",
    "        layers = [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,\n",
    "                            padding=padding, dilation=dilation, groups=groups, bias=bias),\n",
    "                  nn.BatchNorm2d(out_channels)]\n",
    "        if activation:\n",
    "            layers.append(nn.ReLU(inplace=True))\n",
    "        self.seq = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Apply the packed Conv-BN(-ReLU) sequence.\n",
    "        return self.seq(x)\n",
    "\n",
    "class BasicBlock(nn.Module):\n",
    "    \"\"\"\n",
    "    Basic building block for ResNet-18 / ResNet-34: two 3x3 convolutions\n",
    "    plus an identity (or 1x1 projection) shortcut.\n",
    "    \"\"\"\n",
    "    message = \"basic\"  # tells ResNet not to expand channels by 4\n",
    "\n",
    "    def __init__(self, in_channels, out_channels, strides, is_se=False):\n",
    "        super(BasicBlock, self).__init__()\n",
    "        self.is_se = is_se\n",
    "        # Two 3x3 conv layers; kernel sizes/strides follow the ResNet paper.\n",
    "        self.conv1 = BN_Conv2d(in_channels, out_channels, 3, stride=strides, padding=1, bias=False)  # same padding\n",
    "        self.conv2 = BN_Conv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=False, activation=False)\n",
    "        if self.is_se:\n",
    "            # NOTE(review): requires a squeeze-excitation `SE` module, which\n",
    "            # is not defined in this notebook -- is_se=True will fail as-is.\n",
    "            self.se = SE(out_channels, 16)\n",
    "\n",
    "        # Project the input when the residual output has a different spatial\n",
    "        # size or channel count; identity otherwise.\n",
    "        # FIX: was `if strides is not 1` -- identity comparison with an int\n",
    "        # literal is implementation-dependent and a SyntaxWarning on 3.8+;\n",
    "        # also project on a channel mismatch so the addition always shapes up.\n",
    "        self.short_cut = nn.Sequential()\n",
    "        if strides != 1 or in_channels != out_channels:\n",
    "            self.short_cut = nn.Sequential(\n",
    "                nn.Conv2d(in_channels, out_channels, 1, stride=strides, padding=0, bias=False),\n",
    "                nn.BatchNorm2d(out_channels)\n",
    "            )\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.conv1(x)\n",
    "        out = self.conv2(out)\n",
    "        if self.is_se:\n",
    "            # Channel-wise recalibration coefficients from the SE branch.\n",
    "            coefficient = self.se(out)\n",
    "            out = out * coefficient\n",
    "        # Residual addition, then the block's final ReLU.\n",
    "        out = out + self.short_cut(x)\n",
    "        return F.relu(out)\n",
    "\n",
    "class BottleNeck(nn.Module):\n",
    "    \"\"\"\n",
    "    Bottleneck block for ResNet-50, ResNet-101, ResNet-152:\n",
    "    1x1 reduce -> 3x3 -> 1x1 expand (x4), plus a projected shortcut.\n",
    "    \"\"\"\n",
    "    message = \"bottleneck\"  # tells ResNet to expand channels by 4\n",
    "\n",
    "    def __init__(self, in_channels, out_channels, strides, is_se=False):\n",
    "        super(BottleNeck, self).__init__()\n",
    "        self.is_se = is_se\n",
    "        # Kernel sizes/strides likewise follow the ResNet paper.\n",
    "        self.conv1 = BN_Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False)  # same padding\n",
    "        self.conv2 = BN_Conv2d(out_channels, out_channels, 3, stride=strides, padding=1, bias=False)\n",
    "        self.conv3 = BN_Conv2d(out_channels, out_channels * 4, 1, stride=1, padding=0, bias=False, activation=False)\n",
    "        if self.is_se:\n",
    "            # NOTE(review): requires a squeeze-excitation `SE` module, which\n",
    "            # is not defined in this notebook -- is_se=True will fail as-is.\n",
    "            self.se = SE(out_channels * 4, 16)\n",
    "\n",
    "        # Project the input to match the residual output (out_channels * 4).\n",
    "        # NOTE(review): this projects even when shapes already match\n",
    "        # (stride 1, in_channels == out_channels * 4); standard ResNet uses\n",
    "        # an identity shortcut there -- confirm before reusing elsewhere.\n",
    "        self.shortcut = nn.Sequential(\n",
    "            nn.Conv2d(in_channels, out_channels * 4, 1, stride=strides, padding=0, bias=False),\n",
    "            nn.BatchNorm2d(out_channels * 4)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.conv1(x)\n",
    "        out = self.conv2(out)\n",
    "        out = self.conv3(out)\n",
    "        if self.is_se:\n",
    "            # Channel-wise recalibration coefficients from the SE branch.\n",
    "            coefficient = self.se(out)\n",
    "            out = out * coefficient\n",
    "        # Residual addition, then the block's final ReLU.\n",
    "        out = out + self.shortcut(x)\n",
    "        return F.relu(out)\n",
    "\n",
    "class ResNet(nn.Module):\n",
    "    \"\"\"\n",
    "    Generic ResNet for 224x224 RGB input: 7x7 stem conv -> max pool ->\n",
    "    four residual stages (conv2_x..conv5_x) -> 7x7 average pool -> FC head.\n",
    "\n",
    "    :param block: residual block class (BasicBlock or BottleNeck)\n",
    "    :param groups: number of blocks in each of the four stages\n",
    "    :param num_classes: size of the classification head\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, block: object, groups: object, num_classes=1000) -> None:\n",
    "        super(ResNet, self).__init__()\n",
    "        self.channels = 64  # out channels from the first convolutional layer\n",
    "        self.block = block\n",
    "\n",
    "        self.conv1 = nn.Conv2d(3, self.channels, 7, stride=2, padding=3, bias=False)\n",
    "        self.bn = nn.BatchNorm2d(self.channels)\n",
    "        self.pool1 = nn.MaxPool2d(3, 2, 1)\n",
    "        self.conv2_x = self._make_conv_x(channels=64, blocks=groups[0], strides=1, index=2)\n",
    "        self.conv3_x = self._make_conv_x(channels=128, blocks=groups[1], strides=2, index=3)\n",
    "        self.conv4_x = self._make_conv_x(channels=256, blocks=groups[2], strides=2, index=4)\n",
    "        self.conv5_x = self._make_conv_x(channels=512, blocks=groups[3], strides=2, index=5)\n",
    "        self.pool2 = nn.AvgPool2d(7)\n",
    "        # Bottleneck blocks emit 4x the nominal channel count.\n",
    "        patches = 512 if self.block.message == \"basic\" else 512 * 4\n",
    "        self.fc = nn.Linear(patches, num_classes)  # for 224 * 224 input size\n",
    "\n",
    "    def _make_conv_x(self, channels, blocks, strides, index):\n",
    "        \"\"\"\n",
    "        making convolutional group\n",
    "        :param channels: output channels of the conv-group\n",
    "        :param blocks: number of blocks in the conv-group\n",
    "        :param strides: strides\n",
    "        :param index: stage number, used only for unique sub-module names\n",
    "        :return: conv-group\n",
    "        \"\"\"\n",
    "        list_strides = [strides] + [1] * (blocks - 1)  # In conv_x groups, the first strides is 2, the others are ones.\n",
    "        conv_x = nn.Sequential()\n",
    "        for i in range(len(list_strides)):\n",
    "            layer_name = str(\"block_%d_%d\" % (index, i))  # when use add_module, the name should be difference.\n",
    "            conv_x.add_module(layer_name, self.block(self.channels, channels, list_strides[i]))\n",
    "            # Basic blocks keep the nominal channel count; bottlenecks expand\n",
    "            # by 4 (per the paper), so track the next block's input channels.\n",
    "            self.channels = channels if self.block.message == \"basic\" else channels * 4\n",
    "        return conv_x\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.conv1(x)\n",
    "        out = F.relu(self.bn(out))\n",
    "        out = self.pool1(out)\n",
    "        out = self.conv2_x(out)\n",
    "        out = self.conv3_x(out)\n",
    "        out = self.conv4_x(out)\n",
    "        out = self.conv5_x(out)\n",
    "        out = self.pool2(out)\n",
    "        out = out.view(out.size(0), -1)\n",
    "        # FIX: explicit dim=1 (over classes). Dim-less F.softmax is deprecated;\n",
    "        # PyTorch infers dim 1 for 2-D input anyway, so behavior is unchanged.\n",
    "        # NOTE(review): softmax here is redundant/harmful if training with\n",
    "        # nn.CrossEntropyLoss, which expects raw logits.\n",
    "        out = F.softmax(self.fc(out), dim=1)\n",
    "        return out\n",
    "\n",
    "def ResNet_18(num_classes=1000):\n",
    "    \"\"\"ResNet-18: BasicBlock stages of [2, 2, 2, 2].\"\"\"\n",
    "    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)\n",
    "\n",
    "def ResNet_34(num_classes=1000):\n",
    "    \"\"\"ResNet-34: BasicBlock stages of [3, 4, 6, 3].\"\"\"\n",
    "    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)\n",
    "\n",
    "def ResNet_50(num_classes=1000):\n",
    "    \"\"\"ResNet-50: BottleNeck stages of [3, 4, 6, 3].\"\"\"\n",
    "    return ResNet(BottleNeck, [3, 4, 6, 3], num_classes=num_classes)\n",
    "\n",
    "def ResNet_101(num_classes=1000):\n",
    "    \"\"\"ResNet-101: BottleNeck stages of [3, 4, 23, 3].\"\"\"\n",
    "    return ResNet(BottleNeck, [3, 4, 23, 3], num_classes=num_classes)\n",
    "\n",
    "def ResNet_152(num_classes=1000):\n",
    "    \"\"\"ResNet-152: BottleNeck stages of [3, 8, 36, 3].\"\"\"\n",
    "    return ResNet(BottleNeck, [3, 8, 36, 3], num_classes=num_classes)\n",
    "\n",
    "net = ResNet_18()\n",
    "# net = ResNet_34()\n",
    "# net = ResNet_50()\n",
    "# net = ResNet_101()\n",
    "# net = ResNet_152()\n",
    "# FIX: pick the available device instead of unconditionally calling .cuda(),\n",
    "# which raises on CUDA-less machines; pass the same device to torchsummary\n",
    "# (its default is 'cuda').\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "net = net.to(device)\n",
    "summary(net, (3, 224, 224), device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed95c4d2",
   "metadata": {},
   "source": [
    "# 使用PyTorch完成Fashion-MNIST分类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "984ea1b2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\env\\Anaconda3\\lib\\site-packages\\torchvision\\datasets\\mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  ..\\torch\\csrc\\utils\\tensor_numpy.cpp:180.)\n",
      "  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CNN(\n",
      "  (layer1): Sequential(\n",
      "    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n",
      "    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (2): ReLU()\n",
      "  )\n",
      "  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  (layer2): Sequential(\n",
      "    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))\n",
      "    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (2): ReLU()\n",
      "  )\n",
      "  (layer3): Sequential(\n",
      "    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n",
      "    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (2): ReLU()\n",
      "  )\n",
      "  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  (fc1): Linear(in_features=1600, out_features=128, bias=True)\n",
      "  (fc2): Linear(in_features=128, out_features=10, bias=True)\n",
      ")\n",
      "Epoch : 1/50, Iter : 100/0,  Loss: 0.4179\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-1-d01fb1384a4a>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     77\u001b[0m \u001b[0mlosses\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     78\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;34m'EPOCHS'\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 79\u001b[1;33m     \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mimages\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtrainloader\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     80\u001b[0m         \u001b[0mimages\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mimages\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     81\u001b[0m         \u001b[0mlabels\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mlabels\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\torch\\utils\\data\\dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    519\u001b[0m             \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_sampler_iter\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    520\u001b[0m                 \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_reset\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 521\u001b[1;33m             \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    522\u001b[0m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    523\u001b[0m             \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[1;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\torch\\utils\\data\\dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    559\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0m_next_data\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    560\u001b[0m         \u001b[0mindex\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_next_index\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m  \u001b[1;31m# may raise StopIteration\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 561\u001b[1;33m         \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_dataset_fetcher\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfetch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[1;33m)\u001b[0m  \u001b[1;31m# may raise StopIteration\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    562\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_pin_memory\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    563\u001b[0m             \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_utils\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py\u001b[0m in \u001b[0;36mfetch\u001b[1;34m(self, possibly_batched_index)\u001b[0m\n\u001b[0;32m     42\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     43\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 44\u001b[1;33m             \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0midx\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     45\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     46\u001b[0m             \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\torch\\utils\\data\\_utils\\fetch.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m     42\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     43\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 44\u001b[1;33m             \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0midx\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     45\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     46\u001b[0m             \u001b[0mdata\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\torchvision\\datasets\\mnist.py\u001b[0m in \u001b[0;36m__getitem__\u001b[1;34m(self, index)\u001b[0m\n\u001b[0;32m    132\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    133\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtransform\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 134\u001b[1;33m             \u001b[0mimg\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    135\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    136\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtarget_transform\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\torchvision\\transforms\\transforms.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, img)\u001b[0m\n\u001b[0;32m     58\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mimg\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     59\u001b[0m         \u001b[1;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtransforms\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 60\u001b[1;33m             \u001b[0mimg\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mt\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     61\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[0mimg\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     62\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1049\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m   1050\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1051\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1052\u001b[0m         \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1053\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\torchvision\\transforms\\transforms.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, tensor)\u001b[0m\n\u001b[0;32m    219\u001b[0m             \u001b[0mTensor\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mNormalized\u001b[0m \u001b[0mTensor\u001b[0m \u001b[0mimage\u001b[0m\u001b[1;33m.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    220\u001b[0m         \"\"\"\n\u001b[1;32m--> 221\u001b[1;33m         \u001b[1;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnormalize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstd\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minplace\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    222\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    223\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0m__repr__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\torchvision\\transforms\\functional.py\u001b[0m in \u001b[0;36mnormalize\u001b[1;34m(tensor, mean, std, inplace)\u001b[0m\n\u001b[0;32m    330\u001b[0m         \u001b[1;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'std evaluated to zero after conversion to {}, leading to division by zero.'\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdtype\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    331\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mmean\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mndim\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 332\u001b[1;33m         \u001b[0mmean\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmean\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mview\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    333\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mstd\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mndim\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    334\u001b[0m         \u001b[0mstd\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mstd\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mview\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "#!/usr/bin/env python\n",
    "# -*- coding: utf-8 -*-\n",
    "# @File  : 10.使用PyTorch完成Fashion-MNIST分类.py\n",
    "# @Author: Richard Chiming Xu\n",
    "# @Date  : 2021/11/16\n",
    "# @Desc  : Train a small CNN classifier on Fashion-MNIST.\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "from torchvision import transforms\n",
    "import torchvision\n",
    "from torch.utils.data import Dataset\n",
    "\n",
    "# Use the GPU when available, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Hyperparameters gathered in one place for easy tuning.\n",
    "config = {\n",
    "    'BATCH_SIZE': 256,\n",
    "    'LEARNING_RATE': 0.01,\n",
    "    'EPOCHS': 50\n",
    "}\n",
    "\n",
    "\n",
    "# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) maps to [-1, 1].\n",
    "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])\n",
    "\n",
    "# Download/load the Fashion-MNIST train and test splits.\n",
    "# NOTE(review): hard-coded absolute Windows path -- adjust per machine.\n",
    "trainset = torchvision.datasets.FashionMNIST('D:/env/sample_data/', download=True, train=True, transform=transform)\n",
    "testset = torchvision.datasets.FashionMNIST('D:/env/sample_data/', download=True, train=False, transform=transform)\n",
    "\n",
    "# Wrap the datasets in shuffling DataLoaders.\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=config['BATCH_SIZE'], shuffle=True)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=config['BATCH_SIZE'], shuffle=True)\n",
    "\n",
    "# Build the network: four Conv-BN-ReLU stages with two max-pools,\n",
    "# then two fully-connected layers for the 10-way output.\n",
    "class CNN(nn.Module):\n",
    "    \"\"\"Small CNN for 1x28x28 Fashion-MNIST images.\n",
    "\n",
    "    The trailing comments track the (C, H, W) feature-map shape per stage.\n",
    "    \"\"\"\n",
    "    def __init__(self) -> None:\n",
    "        super(CNN, self).__init__()\n",
    "        self.layer1 = nn.Sequential(\n",
    "            nn.Conv2d(1, 16, kernel_size=5, padding=2),\n",
    "            nn.BatchNorm2d(16),\n",
    "            nn.ReLU())  # 16, 28, 28\n",
    "        self.pool1 = nn.MaxPool2d(2)  # 16, 14, 14\n",
    "        self.layer2 = nn.Sequential(\n",
    "            nn.Conv2d(16, 32, kernel_size=3),\n",
    "            nn.BatchNorm2d(32),\n",
    "            nn.ReLU())  # 32, 12, 12\n",
    "        self.layer3 = nn.Sequential(\n",
    "            nn.Conv2d(32, 64, kernel_size=3),\n",
    "            nn.BatchNorm2d(64),\n",
    "            nn.ReLU())  # 64, 10, 10\n",
    "        self.layer4 = nn.Sequential(\n",
    "            nn.Conv2d(64, 128, kernel_size=3),\n",
    "            nn.BatchNorm2d(128),\n",
    "            nn.ReLU())  # 128, 8, 8\n",
    "        self.pool2 = nn.MaxPool2d(2)  # 128, 4, 4\n",
    "        self.fc1 = nn.Linear(4 * 4 * 128, 128)  # flattened 128x4x4 feature map\n",
    "        self.fc2 = nn.Linear(128, 10)  # 10 Fashion-MNIST classes\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Returns raw logits: CrossEntropyLoss applies log-softmax itself.\n",
    "        out = self.layer1(x)\n",
    "        out = self.pool1(out)\n",
    "\n",
    "        out = self.layer2(out)\n",
    "\n",
    "        out = self.layer3(out)\n",
    "        out = self.layer4(out)\n",
    "        out = self.pool2(out)\n",
    "\n",
    "        # Flatten to (batch, 2048) for the fully-connected head.\n",
    "        out = out.view(out.size(0), -1)\n",
    "        out = self.fc1(out)\n",
    "        out = self.fc2(out)\n",
    "\n",
    "        return out\n",
    "\n",
    "cnn = CNN()\n",
    "# FIX: .to(device) works on both CPU and GPU; the original cnn.cuda(device)\n",
    "# raises when CUDA is unavailable even though `device` falls back to 'cpu'.\n",
    "cnn.to(device)\n",
    "print(cnn)\n",
    "criterion = nn.CrossEntropyLoss().to(device)\n",
    "optimizer = torch.optim.Adam(cnn.parameters(), lr=config['LEARNING_RATE'])\n",
    "\n",
    "\n",
    "# Training loop: record the per-batch loss and log every 100 iterations.\n",
    "losses = []\n",
    "for epoch in range(config['EPOCHS']):\n",
    "    for i, (images, labels) in enumerate(trainloader):\n",
    "        images = images.float().to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        optimizer.zero_grad()  # clear gradients accumulated by the last step\n",
    "        outputs = cnn(images)\n",
    "        loss = criterion(outputs, labels)  # cross-entropy on raw logits\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        losses.append(loss.item())  # .item() already syncs to a Python float\n",
    "        if (i + 1) % 100 == 0:\n",
    "            # FIX: len(trainloader) is already the number of batches per epoch;\n",
    "            # the original divided it by BATCH_SIZE again, printing 'Iter : 100/0'.\n",
    "            print('Epoch : %d/%d, Iter : %d/%d,  Loss: %.4f' % (\n",
    "                epoch + 1, config['EPOCHS'], i + 1, len(trainloader), loss.item()))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a1628fc",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "# 使用PyTorch完成人脸关键点检测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "253d71ce",
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "#!/usr/bin/env python\n",
    "# -*- coding: utf-8 -*-\n",
    "# @File  : 11.使用PyTorch完成人脸关键点检测.py\n",
    "# @Author: Richard Chiming Xu\n",
    "# @Date  : 2021/11/17\n",
    "# @Desc  :\n",
    "\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from PIL import Image, ImageDraw\n",
    "from torch import nn\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data.dataset import Dataset\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "import time\n",
    "\n",
    "# Load the keypoint table and image stacks. Missing keypoints are imputed\n",
    "# with 48 — presumably the centre of the 96x96 crop (labels are later divided\n",
    "# by 96.0) — TODO confirm this imputation is intended.\n",
    "train_df = pd.read_csv('data/face/train.csv')\n",
    "train_df = train_df.fillna(48)\n",
    "train_img = np.load('data/face/train.npy')\n",
    "test_img = np.load('data/face/test.npy')\n",
    "\n",
    "\n",
    "\n",
    "torch.backends.cudnn.benchmark = False\n",
    "# Use the GPU when one is available, otherwise fall back to the CPU\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Hyper-parameters (NOTE(review): the loaders below hardcode batch_size=10,\n",
    "# so BATCH_SIZE is unused in this cell)\n",
    "config = {\n",
    "    'BATCH_SIZE': 256,\n",
    "    'LEARNING_RATE': 0.01,\n",
    "    'EPOCHS': 50\n",
    "}\n",
    "\n",
    "\n",
    "class XunFeiDataset(Dataset):\n",
    "    \"\"\"Dataset over an (H, W, N) image stack with per-image keypoints.\n",
    "\n",
    "    Keypoint coordinates are divided by 96.0 so targets lie roughly in [0, 1].\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, img, keypoint, transform=None):\n",
    "        self.img = img\n",
    "        self.keypoint = keypoint\n",
    "        self.transform = transform\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        # Images are stacked along the LAST axis; convert the slice to 3-channel PIL.\n",
    "        sample = Image.fromarray(self.img[:, :, index]).convert('RGB')\n",
    "        if self.transform is not None:\n",
    "            sample = self.transform(sample)\n",
    "        # Normalise keypoints by the 96-pixel image size.\n",
    "        return sample, self.keypoint[index] / 96.0\n",
    "\n",
    "    def __len__(self):\n",
    "        # One sample per slice along the stack's last dimension.\n",
    "        return self.img.shape[-1]\n",
    "\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    XunFeiDataset(train_img[:, :, :-500], train_df.values[:-500],\n",
    "                    transforms.Compose([\n",
    "                        transforms.ToTensor(),\n",
    "        ])\n",
    "    ),\n",
    "    batch_size=10, shuffle=True\n",
    ")\n",
    "\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    XunFeiDataset(train_img[:, :, -500:], train_df.values[-500:],\n",
    "                    transforms.Compose([\n",
    "                        transforms.ToTensor(),\n",
    "        ])\n",
    "    ),\n",
    "    batch_size=10, shuffle=False\n",
    ")\n",
    "\n",
    "# Build the keypoint-regression network\n",
    "class CNN(nn.Module):\n",
    "    \"\"\"Small conv net regressing 8 keypoint coordinates from a 96x96 RGB face.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(CNN, self).__init__()\n",
    "        self.layer1 = nn.Sequential(\n",
    "            nn.Conv2d(3, 16, kernel_size=5, padding=2),\n",
    "            nn.BatchNorm2d(16),\n",
    "            nn.ReLU())  # 16, 96, 96 (padding=2 with k=5 preserves the 96x96 input)\n",
    "        self.pool1 = nn.MaxPool2d(2)  # 16, 48, 48\n",
    "        self.layer2 = nn.Sequential(\n",
    "            nn.Conv2d(16, 32, kernel_size=3),\n",
    "            nn.BatchNorm2d(32),\n",
    "            nn.ReLU())  # 32, 46, 46 (unpadded 3x3 conv shrinks each side by 2)\n",
    "        self.layer3 = nn.Sequential(\n",
    "            nn.Conv2d(32, 64, kernel_size=3),\n",
    "            nn.BatchNorm2d(64),\n",
    "            nn.ReLU())  # 64, 44, 44\n",
    "        self.pool2 = nn.MaxPool2d(2)  # 64, 22, 22\n",
    "        self.layer4 = nn.Sequential(\n",
    "            nn.Conv2d(64, 128, kernel_size=3),\n",
    "            nn.BatchNorm2d(128),\n",
    "            nn.ReLU())  # 128, 20, 20\n",
    "        self.pool3 = nn.MaxPool2d(2) # 128, 10, 10\n",
    "        self.fc1 = nn.Linear(10 * 10 * 128, 128)  # flattened 128x10x10 feature map\n",
    "        self.fc2 = nn.Linear(128, 8)  # 8 outputs = 4 keypoints x (x, y)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # x: (batch, 3, 96, 96); shapes below shown for batch_size=10\n",
    "\n",
    "        out = self.layer1(x) # out.shape = 10,16,96,96\n",
    "        out = self.pool1(out) # out.shape = 10,16,48,48\n",
    "        out = self.layer2(out) #out.shape = 10,32,46,46\n",
    "\n",
    "\n",
    "        out = self.layer3(out) #out.shape = 10,64,44,44\n",
    "        out = self.pool2(out) #out.shape = 10,64,22,22\n",
    "\n",
    "        out = self.layer4(out) #out.shape = 10,128,20,20\n",
    "        out = self.pool3(out) #out.shape = 10,128,10,10\n",
    "\n",
    "        # Flatten and regress; outputs stay in the /96 normalised coordinate scale\n",
    "        out = out.view(out.size(0), -1)\n",
    "        out = self.fc1(out)\n",
    "        out = self.fc2(out)\n",
    "        return out\n",
    "\n",
    "\n",
    "def train(train_loader, model, criterion, optimizer, epoch):\n",
    "    \"\"\"Run one training epoch on the GPU; prints the loss every 200 batches.\n",
    "\n",
    "    Args mirror the call site below: a DataLoader, the model, an MSE criterion,\n",
    "    its optimizer, and the (unused here) epoch index.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "\n",
    "    for i, (input, target) in enumerate(train_loader):\n",
    "        input = input.cuda(non_blocking=True).float()\n",
    "        target = target.cuda(non_blocking=True).float()\n",
    "\n",
    "        output = model(input)\n",
    "        loss = criterion(output, target)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        # BUGFIX: clip AFTER backward() and BEFORE step(). The original called\n",
    "        # clip_grad_norm_ before backward(), when .grad is empty or stale, so\n",
    "        # the clipping had no effect on the update.\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
    "        optimizer.step()\n",
    "\n",
    "        if i % 200 == 0:\n",
    "            print(loss.item())\n",
    "\n",
    "\n",
    "def validate(val_loader, model):\n",
    "    \"\"\"Run the model over val_loader without gradients.\n",
    "\n",
    "    Returns a list of numpy arrays, one per batch, still in the /96-normalised\n",
    "    coordinate scale (the caller multiplies by 96).\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "\n",
    "    val_feats = []\n",
    "    with torch.no_grad():\n",
    "        # Dropped the dead 'end = time.time()' (never read) and the pointless\n",
    "        # GPU transfer of 'target' (never used in validation).\n",
    "        for i, (input, target) in enumerate(val_loader):\n",
    "            input = input.cuda().float()\n",
    "            output = model(input)\n",
    "            val_feats.append(output.data.cpu().numpy())\n",
    "    return val_feats\n",
    "\n",
    "\n",
    "from sklearn.metrics import mean_absolute_error\n",
    "# BUGFIX: 'import torchvision.transforms as transforms' binds only the alias\n",
    "# 'transforms', NOT the package name — torchvision.models.resnet18 below\n",
    "# raised NameError without this import.\n",
    "import torchvision\n",
    "\n",
    "# Hand-rolled CNN alternative (kept for comparison):\n",
    "# model = CNN()\n",
    "# Fine-tune a pretrained resnet18 with an 8-way regression head instead.\n",
    "model = torchvision.models.resnet18(pretrained=True)\n",
    "model.fc = nn.Linear(model.fc.in_features, 8, bias=False)\n",
    "\n",
    "\n",
    "model.cuda(device)\n",
    "criterion = nn.MSELoss().cuda()\n",
    "optimizer = torch.optim.Adam(model.parameters(), 0.001)\n",
    "scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.85)\n",
    "# (removed the unused 'best_acc = 0.0' — nothing in the notebook reads it)\n",
    "\n",
    "for epoch in range(5):\n",
    "    print('Epoch: ', epoch)\n",
    "\n",
    "    train(train_loader, model, criterion, optimizer, epoch)\n",
    "\n",
    "    val_feats = validate(val_loader, model)\n",
    "    scheduler.step()\n",
    "\n",
    "    # Undo the /96 normalisation before comparing against raw pixel labels.\n",
    "    val_feats = np.vstack(val_feats) * 96\n",
    "    print('Val', mean_absolute_error(val_feats, train_df.values[-500:]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d5d3d8f",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "# 使用PyTorch搭建对抗生成网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "725a8b2c",
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "#!/usr/bin/env python\n",
    "# -*- coding: utf-8 -*-\n",
    "# @File  : 12.使用PyTorch搭建对抗生成网络.py\n",
    "# @Author: Richard Chiming Xu\n",
    "# @Date  : 2021/11/18\n",
    "# @Desc  :\n",
    "\n",
    "import random\n",
    "import torch.nn.parallel\n",
    "import torch.backends.cudnn as cudnn\n",
    "import torch.optim as optim\n",
    "import torch.utils.data\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.utils as vutils\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from GanNetwork import *\n",
    "import os\n",
    "\n",
    "from PIL import Image\n",
    "import pandas as pd\n",
    "from torch.utils.data.dataset import Dataset\n",
    "\n",
    "# Work around the duplicate-OpenMP-runtime crash (Anaconda + matplotlib on Windows)\n",
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\"\n",
    "\n",
    "config = {\n",
    "    'BATCH_SIZE': 64,  # batch size\n",
    "    'LEARNING_RATE': 0.0003,  # learning rate for both Adam optimizers\n",
    "    'EPOCHS': 100,  # number of training epochs\n",
    "    'IMAGE_SIZE': 64,  # images are resized/centre-cropped to this size\n",
    "    'NC': 3,  # discriminator input channels (RGB)\n",
    "    'NZ': 100,  # length of the latent z vector fed to the generator\n",
    "    'NGF': 64,  # base feature-map width of the generator\n",
    "    'NDF': 64,  # base feature-map width of the discriminator (passed to Discriminator below)\n",
    "    'BETA': 0.5,  # beta1 for Adam\n",
    "    'RANDOM_SEED': 2021\n",
    "}\n",
    "\n",
    "# Seed the python and torch RNGs for reproducibility\n",
    "# (NOTE(review): numpy's RNG is not seeded here — confirm whether that matters)\n",
    "print(\"Random Seed: \", config['RANDOM_SEED'])\n",
    "random.seed(config['RANDOM_SEED'])\n",
    "torch.manual_seed(config['RANDOM_SEED'])\n",
    "# Check for a GPU\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "'''\n",
    "    读取数据\n",
    "'''\n",
    "# Load the raw face images and the (unused here) keypoint table\n",
    "train_img = np.load('data/face/train.npy')\n",
    "train_df = pd.read_csv('data/face/train.csv')\n",
    "torch.backends.cudnn.benchmark = False\n",
    "# NOTE(review): 'device' was already assigned identically above; this\n",
    "# re-assignment is redundant but harmless\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "\n",
    "# Custom dataset for the GAN: images only, labels are placeholders\n",
    "class XunFeiDataset(Dataset):\n",
    "    \"\"\"Wrap an (H, W, N) image stack; GAN training needs no real labels.\"\"\"\n",
    "\n",
    "    def __init__(self, img, batch_size, transform=None):\n",
    "        self.img = img\n",
    "        self.transform = transform\n",
    "        self.batch_size = batch_size\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        # Images are stacked along the last axis; convert to 3-channel PIL\n",
    "        img = Image.fromarray(self.img[:, :, index]).convert('RGB')\n",
    "\n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "        # Labels are never used, so return zeros as a placeholder.\n",
    "        # NOTE(review): this is sized batch_size PER SAMPLE, so each loader\n",
    "        # batch carries a (BATCH_SIZE, BATCH_SIZE) label tensor — harmless\n",
    "        # here since data[1] is ignored, but a scalar 0 would be cleaner.\n",
    "        return img, torch.zeros(self.batch_size)\n",
    "\n",
    "    def __len__(self):\n",
    "        # One sample per slice along the stack's last dimension\n",
    "        return self.img.shape[-1]\n",
    "\n",
    "\n",
    "# Build the dataloader: resize/crop to IMAGE_SIZE, scale pixels to [-1, 1]\n",
    "# (matches a tanh generator output)\n",
    "dataloader = torch.utils.data.DataLoader(\n",
    "    XunFeiDataset(train_img[:, :, :-500], config['BATCH_SIZE'], transforms.Compose([\n",
    "        transforms.Resize(config['IMAGE_SIZE']),\n",
    "        transforms.CenterCrop(config['IMAGE_SIZE']),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
    "    ])\n",
    "                  ), batch_size=config['BATCH_SIZE'], shuffle=True\n",
    ")\n",
    "\n",
    "'''\n",
    "    创建模型，定义优化器和损失函数\n",
    "'''\n",
    "# Create the generator (Generator/weights_init come from GanNetwork)\n",
    "netG = Generator(config['NC'], config['NZ'], config['NGF']).to(device)\n",
    "netG.apply(weights_init)\n",
    "print(netG)\n",
    "\n",
    "# Create the discriminator\n",
    "netD = Discriminator(config['NC'], config['NDF']).to(device)\n",
    "netD.apply(weights_init)\n",
    "print(netD)\n",
    "\n",
    "# Binary cross-entropy for the real-vs-fake decision\n",
    "criterion = nn.BCELoss()\n",
    "\n",
    "# Fixed noise batch, reused every epoch so generated samples are comparable\n",
    "fixed_noise = torch.randn(64, config['NZ'], 1, 1, device=device)\n",
    "\n",
    "# Target label values fed to the discriminator\n",
    "real_label = 1\n",
    "fake_label = 0\n",
    "\n",
    "# One Adam optimizer per network\n",
    "optimizerD = optim.Adam(netD.parameters(), lr=config['LEARNING_RATE'], betas=(config['BETA'], 0.999))\n",
    "optimizerG = optim.Adam(netG.parameters(), lr=config['LEARNING_RATE'], betas=(config['BETA'], 0.999))\n",
    "\n",
    "# History buffers for losses and sample grids\n",
    "img_list = []\n",
    "G_losses = []\n",
    "D_losses = []\n",
    "iters = 0\n",
    "\n",
    "'''\n",
    "    训练模型\n",
    "'''\n",
    "print(\"初始化完毕看是训练...\")\n",
    "for epoch in range(config['EPOCHS']):\n",
    "    import time  # NOTE(review): harmless, but belongs in the top import block\n",
    "\n",
    "    start = time.time()\n",
    "    # Iterate over the dataloader\n",
    "    for i, data in enumerate(dataloader, 0):\n",
    "\n",
    "        ############################\n",
    "        # (1) Update the discriminator: maximize log(D(x)) + log(1 - D(G(z)))\n",
    "        ###########################\n",
    "        ## Train on an all-real batch\n",
    "        netD.zero_grad()\n",
    "        real_cpu = data[0].to(device)\n",
    "        b_size = real_cpu.size(0)\n",
    "        # Fill the label vector with the 'real' label for the real batch\n",
    "        label = torch.full((b_size,), real_label, device=device, dtype=torch.float32)\n",
    "        # Forward the real batch through D\n",
    "        output = netD(real_cpu).view(-1)\n",
    "        # D's loss on real data\n",
    "        errD_real = criterion(output, label)\n",
    "        # Backpropagate to accumulate gradients\n",
    "        errD_real.backward()\n",
    "        D_x = output.mean().item()\n",
    "\n",
    "        ## Train on an all-fake batch: sample noise sized to the real batch\n",
    "        noise = torch.randn(b_size, config['NZ'], 1, 1, device=device, dtype=torch.float32)\n",
    "        # Generate fake images with G\n",
    "        fake = netG(noise)\n",
    "        # Switch the labels to 'fake'\n",
    "        label.fill_(fake_label)\n",
    "        # Classify the fakes with D; detach() so no gradients flow into G here\n",
    "        output = netD(fake.detach()).view(-1)\n",
    "        # D's loss on the fake batch\n",
    "        errD_fake = criterion(output, label)\n",
    "        # Backpropagate (gradients add to those from the real batch)\n",
    "        errD_fake.backward()\n",
    "        D_G_z1 = output.mean().item()\n",
    "        # Total discriminator loss (for reporting)\n",
    "        errD = errD_real + errD_fake\n",
    "        # Update the discriminator\n",
    "        optimizerD.step()\n",
    "\n",
    "        ############################\n",
    "        # (2) Update the generator: maximize log(D(G(z)))\n",
    "        ###########################\n",
    "        netG.zero_grad()\n",
    "        # Use 'real' labels: G wants D to call its fakes real\n",
    "        label.fill_(real_label)\n",
    "        # Re-classify the fakes with the just-updated discriminator\n",
    "        output = netD(fake).view(-1)\n",
    "        # G's loss: distance of D's verdict from 'real'\n",
    "        errG = criterion(output, label)\n",
    "        # Backpropagate into the generator\n",
    "        errG.backward()\n",
    "        D_G_z2 = output.mean().item()\n",
    "        # Update the generator\n",
    "        optimizerG.step()\n",
    "\n",
    "        # Periodic progress report\n",
    "        if i % 50 == 0:\n",
    "            print('[%d/%d][%d/%d]\\tLoss_D: %.4f\\tLoss_G: %.4f\\tD(x): %.4f\\tD(G(z)): %.4f / %.4f'\n",
    "                  % (epoch, config['EPOCHS'], i, len(dataloader),\n",
    "                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))\n",
    "\n",
    "        # Record loss history\n",
    "        G_losses.append(errG.item())\n",
    "        D_losses.append(errD.item())\n",
    "\n",
    "        # Periodically save a grid of samples generated from the fixed noise\n",
    "        if (iters % 20 == 0) or ((epoch == config['EPOCHS'] - 1) and (i == len(dataloader) - 1)):\n",
    "            with torch.no_grad():\n",
    "                fake = netG(fixed_noise).detach().cpu()\n",
    "\n",
    "            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))\n",
    "            # NOTE(review): reassigning 'i' clobbers the batch index for the\n",
    "            # rest of this iteration; nothing reads it afterwards, but it\n",
    "            # deserves its own name\n",
    "            i = vutils.make_grid(fake, padding=2, normalize=True)\n",
    "            fig = plt.figure(figsize=(8, 8))\n",
    "            plt.imshow(np.transpose(i, (1, 2, 0)))\n",
    "            plt.axis('off')  # hide the axes\n",
    "            plt.savefig(\"./out/%d_%d.png\" % (epoch, iters))\n",
    "            plt.close(fig)\n",
    "        iters += 1\n",
    "    print('time:', time.time() - start)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
